diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d1f085b47..7ac5103d9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,7 +10,7 @@ on: env: SIGN_PIPE_VER: "v0.1.1" - GORELEASER_VER: "v2.3.2" + GORELEASER_VER: "v2.14.3" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" @@ -169,6 +169,13 @@ jobs: - name: Install OS build dependencies run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu + - name: Decode GPG signing key + env: + GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }} + run: | + echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc + echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV + - name: Install goversioninfo run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e - name: Generate windows syso amd64 @@ -186,6 +193,24 @@ jobs: HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} + GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }} + NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }} + - name: Verify RPM signatures + run: | + docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c ' + dnf install -y -q rpm-sign curl >/dev/null 2>&1 + curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key + rpm --import /tmp/rpm-pub.key + echo "=== Verifying RPM signatures ===" + for rpm_file in /dist/*amd64*.rpm; do + [ -f "$rpm_file" ] || continue + echo "--- $(basename $rpm_file) ---" + rpm -K "$rpm_file" + done + ' + - name: Clean up GPG key + if: always() + run: rm -f /tmp/gpg-rpm-signing-key.asc - name: Tag and push PR images (amd64 only) if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository run: | @@ -265,6 +290,13 @@ jobs: - name: Install dependencies run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 + - name: Decode GPG signing key + env: + GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }} + run: | + echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc + echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV + - name: Install LLVM-MinGW for ARM64 cross-compilation run: | cd /tmp @@ -289,6 +321,24 @@ jobs: HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} + GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }} + NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }} + - name: Verify RPM signatures + run: | + docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c ' + dnf install -y -q rpm-sign curl >/dev/null 2>&1 + curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key + rpm --import /tmp/rpm-pub.key + echo "=== Verifying RPM signatures ===" + for rpm_file in /dist/*.rpm; do + [ -f "$rpm_file" ] || continue + echo "--- $(basename $rpm_file) ---" + rpm -K "$rpm_file" + done + ' + - name: Clean up GPG key + if: always() + run: rm -f /tmp/gpg-rpm-signing-key.asc - name: upload non tags for debug purposes uses: actions/upload-artifact@v4 with: diff --git a/.goreleaser.yaml b/.goreleaser.yaml index c0a5efbbe..0f81229cd 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -171,13 +171,12 @@ nfpms: - maintainer: Netbird description: Netbird client. homepage: https://netbird.io/ - id: netbird-deb + id: netbird_deb bindir: /usr/bin builds: - netbird formats: - deb - scripts: postinstall: "release_files/post_install.sh" preremove: "release_files/pre_remove.sh" @@ -185,16 +184,18 @@ nfpms: - maintainer: Netbird description: Netbird client. homepage: https://netbird.io/ - id: netbird-rpm + id: netbird_rpm bindir: /usr/bin builds: - netbird formats: - rpm - scripts: postinstall: "release_files/post_install.sh" preremove: "release_files/pre_remove.sh" + rpm: + signature: + key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}' dockers: - image_templates: - netbirdio/netbird:{{ .Version }}-amd64 @@ -876,7 +877,7 @@ brews: uploads: - name: debian ids: - - netbird-deb + - netbird_deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com @@ -884,7 +885,7 @@ uploads: - name: yum ids: - - netbird-rpm + - netbird_rpm mode: archive target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} username: dev@wiretrustee.com diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml index a243702ea..470f1deaa 100644 --- a/.goreleaser_ui.yaml +++ b/.goreleaser_ui.yaml @@ -61,7 +61,7 @@ nfpms: - maintainer: Netbird description: Netbird client UI. homepage: https://netbird.io/ - id: netbird-ui-deb + id: netbird_ui_deb package_name: netbird-ui builds: - netbird-ui @@ -80,7 +80,7 @@ nfpms: - maintainer: Netbird description: Netbird client UI. homepage: https://netbird.io/ - id: netbird-ui-rpm + id: netbird_ui_rpm package_name: netbird-ui builds: - netbird-ui @@ -95,11 +95,14 @@ nfpms: dst: /usr/share/pixmaps/netbird.png dependencies: - netbird + rpm: + signature: + key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}' uploads: - name: debian ids: - - netbird-ui-deb + - netbird_ui_deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com @@ -107,7 +110,7 @@ uploads: - name: yum ids: - - netbird-ui-rpm + - netbird_ui_rpm mode: archive target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} username: dev@wiretrustee.com diff --git a/client/android/client.go b/client/android/client.go index ccf32a90c..3fc571559 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -124,7 +124,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) } @@ -157,7 +157,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) } diff --git a/client/cmd/expose.go b/client/cmd/expose.go index 991d3ab86..1334617d8 100644 --- a/client/cmd/expose.go +++ b/client/cmd/expose.go @@ -22,20 +22,24 @@ import ( var pinRegexp = regexp.MustCompile(`^\d{6}$`) var ( - exposePin string - exposePassword string - exposeUserGroups []string - exposeDomain string - exposeNamePrefix string - exposeProtocol string + exposePin string + exposePassword string + exposeUserGroups []string + exposeDomain string + exposeNamePrefix string + exposeProtocol string + exposeExternalPort uint16 ) var exposeCmd = &cobra.Command{ - Use: "expose ", - Short: "Expose a local port via the NetBird reverse proxy", - Args: cobra.ExactArgs(1), - Example: "netbird expose --with-password safe-pass 8080", - RunE: exposeFn, + Use: "expose ", + Short: "Expose a local port via the NetBird reverse proxy", + Args: cobra.ExactArgs(1), + Example: ` netbird expose --with-password safe-pass 8080 + netbird expose --protocol tcp 5432 + netbird expose --protocol tcp --with-external-port 5433 5432 + netbird expose --protocol tls --with-custom-domain tls.example.com 4443`, + RunE: exposeFn, } func init() { @@ -44,7 +48,52 @@ func init() { exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)") exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)") exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)") - exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)") + exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)") + exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)") +} + +// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags. +func isClusterProtocol(protocol string) bool { + switch strings.ToLower(protocol) { + case "tcp", "udp", "tls": + return true + default: + return false + } +} + +// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP) +// where domain display doesn't apply. TLS uses SNI so it has a domain. +func isPortBasedProtocol(protocol string) bool { + switch strings.ToLower(protocol) { + case "tcp", "udp": + return true + default: + return false + } +} + +// extractPort returns the port portion of a URL like "tcp://host:12345", or +// falls back to the given default formatted as a string. +func extractPort(serviceURL string, fallback uint16) string { + u := serviceURL + if idx := strings.Index(u, "://"); idx != -1 { + u = u[idx+3:] + } + if i := strings.LastIndex(u, ":"); i != -1 { + if p := u[i+1:]; p != "" { + return p + } + } + return strconv.FormatUint(uint64(fallback), 10) +} + +// resolveExternalPort returns the effective external port, defaulting to the target port. +func resolveExternalPort(targetPort uint64) uint16 { + if exposeExternalPort != 0 { + return exposeExternalPort + } + return uint16(targetPort) } func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) { @@ -57,7 +106,15 @@ func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) { } if !isProtocolValid(exposeProtocol) { - return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol) + return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol) + } + + if isClusterProtocol(exposeProtocol) { + if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 { + return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol) + } + } else if cmd.Flags().Changed("with-external-port") { + return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol) } if exposePin != "" && !pinRegexp.MatchString(exposePin) { @@ -76,7 +133,12 @@ func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) { } func isProtocolValid(exposeProtocol string) bool { - return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https" + switch strings.ToLower(exposeProtocol) { + case "http", "https", "tcp", "udp", "tls": + return true + default: + return false + } } func exposeFn(cmd *cobra.Command, args []string) error { @@ -123,7 +185,7 @@ func exposeFn(cmd *cobra.Command, args []string) error { return err } - stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{ + req := &proto.ExposeServiceRequest{ Port: uint32(port), Protocol: protocol, Pin: exposePin, @@ -131,7 +193,12 @@ func exposeFn(cmd *cobra.Command, args []string) error { UserGroups: exposeUserGroups, Domain: exposeDomain, NamePrefix: exposeNamePrefix, - }) + } + if isClusterProtocol(exposeProtocol) { + req.ListenPort = uint32(resolveExternalPort(port)) + } + + stream, err := client.ExposeService(ctx, req) if err != nil { return fmt.Errorf("expose service: %w", err) } @@ -149,8 +216,14 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) { return proto.ExposeProtocol_EXPOSE_HTTP, nil case "https": return proto.ExposeProtocol_EXPOSE_HTTPS, nil + case "tcp": + return proto.ExposeProtocol_EXPOSE_TCP, nil + case "udp": + return proto.ExposeProtocol_EXPOSE_UDP, nil + case "tls": + return proto.ExposeProtocol_EXPOSE_TLS, nil default: - return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol) + return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol) } } @@ -160,20 +233,33 @@ func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServ return fmt.Errorf("receive expose event: %w", err) } - switch e := event.Event.(type) { - case *proto.ExposeServiceEvent_Ready: - cmd.Println("Service exposed successfully!") - cmd.Printf(" Name: %s\n", e.Ready.ServiceName) - cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl) - cmd.Printf(" Domain: %s\n", e.Ready.Domain) - cmd.Printf(" Protocol: %s\n", exposeProtocol) - cmd.Printf(" Port: %d\n", port) - cmd.Println() - cmd.Println("Press Ctrl+C to stop exposing.") - return nil - default: + ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready) + if !ok { return fmt.Errorf("unexpected expose event: %T", event.Event) } + printExposeReady(cmd, ready.Ready, port) + return nil +} + +func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) { + cmd.Println("Service exposed successfully!") + cmd.Printf(" Name: %s\n", r.ServiceName) + if r.ServiceUrl != "" { + cmd.Printf(" URL: %s\n", r.ServiceUrl) + } + if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) { + cmd.Printf(" Domain: %s\n", r.Domain) + } + cmd.Printf(" Protocol: %s\n", exposeProtocol) + cmd.Printf(" Internal: %d\n", port) + if isClusterProtocol(exposeProtocol) { + cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port))) + } + if r.PortAutoAssigned && exposeExternalPort != 0 { + cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort) + } + cmd.Println() + cmd.Println("Press Ctrl+C to stop exposing.") } func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error { diff --git a/client/cmd/signer/artifactkey.go b/client/cmd/signer/artifactkey.go index 5e656650b..ee12326db 100644 --- a/client/cmd/signer/artifactkey.go +++ b/client/cmd/signer/artifactkey.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) var ( diff --git a/client/cmd/signer/artifactsign.go b/client/cmd/signer/artifactsign.go index 881be9367..7c02323dc 100644 --- a/client/cmd/signer/artifactsign.go +++ b/client/cmd/signer/artifactsign.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) const ( diff --git a/client/cmd/signer/revocation.go b/client/cmd/signer/revocation.go index 1d84b65c3..5ff636dcb 100644 --- a/client/cmd/signer/revocation.go +++ b/client/cmd/signer/revocation.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) const ( diff --git a/client/cmd/signer/rootkey.go b/client/cmd/signer/rootkey.go index 78ac36b41..eae0da84d 100644 --- a/client/cmd/signer/rootkey.go +++ b/client/cmd/signer/rootkey.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) var ( diff --git a/client/cmd/up.go b/client/cmd/up.go index 9559287d5..f5766522a 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr r := peer.NewRecorder(config.ManagementURL.String()) r.GetFullStatus() - connectClient := internal.NewConnectClient(ctx, config, r, false) + connectClient := internal.NewConnectClient(ctx, config, r) SetupDebugHandler(ctx, config, r, connectClient, "") return connectClient.Run(nil, util.FindFirstLogPath(logFiles)) diff --git a/client/cmd/update_supported.go b/client/cmd/update_supported.go index 977875093..0b197f4c5 100644 --- a/client/cmd/update_supported.go +++ b/client/cmd/update_supported.go @@ -11,7 +11,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/util" ) diff --git a/client/embed/embed.go b/client/embed/embed.go index 4fbe0eada..21043cf96 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -202,7 +202,7 @@ func (c *Client) Start(startCtx context.Context) error { if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } - client := internal.NewConnectClient(ctx, c.config, c.recorder, false) + client := internal.NewConnectClient(ctx, c.config, c.recorder) client.SetSyncResponsePersistence(true) // either startup error (permanent backoff err) or nil err (successful engine up) diff --git a/client/internal/connect.go b/client/internal/connect.go index 68a0cb8da..ccd7b6c33 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -27,8 +27,8 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/client/internal/updatemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater" + "github.com/netbirdio/netbird/client/internal/updater/installer" nbnet "github.com/netbirdio/netbird/client/net" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" @@ -44,13 +44,13 @@ import ( ) type ConnectClient struct { - ctx context.Context - config *profilemanager.Config - statusRecorder *peer.Status - doInitialAutoUpdate bool + ctx context.Context + config *profilemanager.Config + statusRecorder *peer.Status - engine *Engine - engineMutex sync.Mutex + engine *Engine + engineMutex sync.Mutex + updateManager *updater.Manager persistSyncResponse bool } @@ -59,17 +59,19 @@ func NewConnectClient( ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, - doInitalAutoUpdate bool, ) *ConnectClient { return &ConnectClient{ - ctx: ctx, - config: config, - statusRecorder: statusRecorder, - doInitialAutoUpdate: doInitalAutoUpdate, - engineMutex: sync.Mutex{}, + ctx: ctx, + config: config, + statusRecorder: statusRecorder, + engineMutex: sync.Mutex{}, } } +func (c *ConnectClient) SetUpdateManager(um *updater.Manager) { + c.updateManager = um +} + // Run with main logic. func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error { return c.run(MobileDependency{}, runningChan, logPath) @@ -187,14 +189,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan stateManager := statemanager.New(path) stateManager.RegisterState(&sshconfig.ShutdownState{}) - updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager) - if err == nil { - updateManager.CheckUpdateSuccess(c.ctx) + if c.updateManager != nil { + c.updateManager.CheckUpdateSuccess(c.ctx) + } - inst := installer.New() - if err := inst.CleanUpInstallerFiles(); err != nil { - log.Errorf("failed to clean up temporary installer file: %v", err) - } + inst := installer.New() + if err := inst.CleanUpInstallerFiles(); err != nil { + log.Errorf("failed to clean up temporary installer file: %v", err) } defer c.statusRecorder.ClientStop() @@ -308,7 +309,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan checks := loginResp.GetChecks() c.engineMutex.Lock() - engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager) + engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{ + SignalClient: signalClient, + MgmClient: mgmClient, + RelayManager: relayManager, + StatusRecorder: c.statusRecorder, + Checks: checks, + StateManager: stateManager, + UpdateManager: c.updateManager, + }, mobileDependency) engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine = engine c.engineMutex.Unlock() @@ -318,15 +327,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return wrapErr(err) } - if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil { - // AutoUpdate will be true when the user click on "Connect" menu on the UI - if c.doInitialAutoUpdate { - log.Infof("start engine by ui, run auto-update check") - c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate) - c.doInitialAutoUpdate = false - } - } - log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 0f8243e7a..f0f399bef 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -27,7 +27,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" nbstatus "github.com/netbirdio/netbird/client/status" mgmProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index b374bcc6a..a67a23945 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID { return "local-resolver" } -func (d *Resolver) ProbeAvailability() {} +func (d *Resolver) ProbeAvailability(context.Context) {} // ServeDNS handles a DNS request func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 179517bbd..6ca4f7957 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -104,12 +104,16 @@ type DefaultServer struct { statusRecorder *peer.Status stateManager *statemanager.Manager + + probeMu sync.Mutex + probeCancel context.CancelFunc + probeWg sync.WaitGroup } type handlerWithStop interface { dns.Handler Stop() - ProbeAvailability() + ProbeAvailability(context.Context) ID() types.HandlerID } @@ -362,7 +366,13 @@ func (s *DefaultServer) DnsIP() netip.Addr { // Stop stops the server func (s *DefaultServer) Stop() { + s.probeMu.Lock() + if s.probeCancel != nil { + s.probeCancel() + } s.ctxCancel() + s.probeMu.Unlock() + s.probeWg.Wait() s.shutdownWg.Wait() s.mux.Lock() @@ -479,7 +489,8 @@ func (s *DefaultServer) SearchDomains() []string { } // ProbeAvailability tests each upstream group's servers for availability -// and deactivates the group if no server responds +// and deactivates the group if no server responds. +// If a previous probe is still running, it will be cancelled before starting a new one. func (s *DefaultServer) ProbeAvailability() { if val := os.Getenv(envSkipDNSProbe); val != "" { skipProbe, err := strconv.ParseBool(val) @@ -492,15 +503,52 @@ func (s *DefaultServer) ProbeAvailability() { } } - var wg sync.WaitGroup - for _, mux := range s.dnsMuxMap { - wg.Add(1) - go func(mux handlerWithStop) { - defer wg.Done() - mux.ProbeAvailability() - }(mux.handler) + s.probeMu.Lock() + + // don't start probes on a stopped server + if s.ctx.Err() != nil { + s.probeMu.Unlock() + return } + + // cancel any running probe + if s.probeCancel != nil { + s.probeCancel() + s.probeCancel = nil + } + + // wait for the previous probe goroutines to finish while holding + // the mutex so no other caller can start a new probe concurrently + s.probeWg.Wait() + + // start a new probe + probeCtx, probeCancel := context.WithCancel(s.ctx) + s.probeCancel = probeCancel + + s.probeWg.Add(1) + defer s.probeWg.Done() + + // Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers. + s.mux.Lock() + handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap)) + for _, mux := range s.dnsMuxMap { + handlers = append(handlers, mux.handler) + } + s.mux.Unlock() + + var wg sync.WaitGroup + for _, handler := range handlers { + wg.Add(1) + go func(h handlerWithStop) { + defer wg.Done() + h.ProbeAvailability(probeCtx) + }(handler) + } + + s.probeMu.Unlock() + wg.Wait() + probeCancel() } func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 3606d48b9..d3b0c250d 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -1065,7 +1065,7 @@ type mockHandler struct { func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} func (m *mockHandler) Stop() {} -func (m *mockHandler) ProbeAvailability() {} +func (m *mockHandler) ProbeAvailability(context.Context) {} func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) } type mockService struct{} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 375f6df1c..18128a942 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -65,6 +65,7 @@ type upstreamResolverBase struct { mutex sync.Mutex reactivatePeriod time.Duration upstreamTimeout time.Duration + wg sync.WaitGroup deactivate func(error) reactivate func() @@ -115,6 +116,11 @@ func (u *upstreamResolverBase) MatchSubdomains() bool { func (u *upstreamResolverBase) Stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() + + u.mutex.Lock() + u.wg.Wait() + u.mutex.Unlock() + } // ServeDNS handles a DNS request @@ -260,16 +266,10 @@ func formatFailures(failures []upstreamFailure) string { // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work -func (u *upstreamResolverBase) ProbeAvailability() { +func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) { u.mutex.Lock() defer u.mutex.Unlock() - select { - case <-u.ctx.Done(): - return - default: - } - // avoid probe if upstreams could resolve at least one query if u.successCount.Load() > 0 { return @@ -279,31 +279,39 @@ func (u *upstreamResolverBase) ProbeAvailability() { var mu sync.Mutex var wg sync.WaitGroup - var errors *multierror.Error + var errs *multierror.Error for _, upstream := range u.upstreamServers { - upstream := upstream - wg.Add(1) - go func() { + go func(upstream netip.AddrPort) { defer wg.Done() - err := u.testNameserver(upstream, 500*time.Millisecond) + err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond) if err != nil { - errors = multierror.Append(errors, err) + mu.Lock() + errs = multierror.Append(errs, err) + mu.Unlock() log.Warnf("probing upstream nameserver %s: %s", upstream, err) return } mu.Lock() - defer mu.Unlock() success = true - }() + mu.Unlock() + }(upstream) } wg.Wait() + select { + case <-ctx.Done(): + return + case <-u.ctx.Done(): + return + default: + } + // didn't find a working upstream server, let's disable and try later if !success { - u.disable(errors.ErrorOrNil()) + u.disable(errs.ErrorOrNil()) if u.statusRecorder == nil { return @@ -339,7 +347,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { } for _, upstream := range u.upstreamServers { - if err := u.testNameserver(upstream, probeTimeout); err != nil { + if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil { log.Tracef("upstream check for %s: %s", upstream, err) } else { // at least one upstream server is available, stop probing @@ -364,7 +372,9 @@ func (u *upstreamResolverBase) waitUntilResponse() { log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) u.successCount.Add(1) u.reactivate() + u.mutex.Lock() u.disabled = false + u.mutex.Unlock() } // isTimeout returns true if the given error is a network timeout error. @@ -387,7 +397,11 @@ func (u *upstreamResolverBase) disable(err error) { u.successCount.Store(0) u.deactivate(err) u.disabled = true - go u.waitUntilResponse() + u.wg.Add(1) + go func() { + defer u.wg.Done() + u.waitUntilResponse() + }() } func (u *upstreamResolverBase) upstreamServersString() string { @@ -398,13 +412,18 @@ func (u *upstreamResolverBase) upstreamServersString() string { return strings.Join(servers, ", ") } -func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(u.ctx, timeout) +func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error { + mergedCtx, cancel := context.WithTimeout(baseCtx, timeout) defer cancel() + if externalCtx != nil { + stop2 := context.AfterFunc(externalCtx, cancel) + defer stop2() + } + r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) - _, _, err := u.upstreamClient.exchange(ctx, server.String(), r) + _, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r) return err } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 8b06e4475..ab164c30b 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { reactivated = true } - resolver.ProbeAvailability() + resolver.ProbeAvailability(context.TODO()) if !failed { t.Errorf("expected that resolving was deactivated") diff --git a/client/internal/engine.go b/client/internal/engine.go index b0ae841f8..fd3bdf7af 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -51,7 +51,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager" + "github.com/netbirdio/netbird/client/internal/updater" "github.com/netbirdio/netbird/client/jobexec" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" @@ -79,7 +79,6 @@ const ( var ErrResetConnection = fmt.Errorf("reset connection") -// EngineConfig is a config for the Engine type EngineConfig struct { WgPort int WgIfaceName string @@ -141,6 +140,17 @@ type EngineConfig struct { LogPath string } +// EngineServices holds the external service dependencies required by the Engine. +type EngineServices struct { + SignalClient signal.Client + MgmClient mgm.Client + RelayManager *relayClient.Manager + StatusRecorder *peer.Status + Checks []*mgmProto.Checks + StateManager *statemanager.Manager + UpdateManager *updater.Manager +} + // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. type Engine struct { // signal is a Signal Service client @@ -209,7 +219,7 @@ type Engine struct { flowManager nftypes.FlowManager // auto-update - updateManager *updatemanager.Manager + updateManager *updater.Manager // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor @@ -239,22 +249,17 @@ type localIpUpdater interface { func NewEngine( clientCtx context.Context, clientCancel context.CancelFunc, - signalClient signal.Client, - mgmClient mgm.Client, - relayManager *relayClient.Manager, config *EngineConfig, + services EngineServices, mobileDep MobileDependency, - statusRecorder *peer.Status, - checks []*mgmProto.Checks, - stateManager *statemanager.Manager, ) *Engine { engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, - signal: signalClient, - signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), - mgmClient: mgmClient, - relayManager: relayManager, + signal: services.SignalClient, + signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey), + mgmClient: services.MgmClient, + relayManager: services.RelayManager, peerStore: peerstore.NewConnStore(), syncMsgMux: &sync.Mutex{}, config: config, @@ -262,11 +267,12 @@ func NewEngine( STUNs: []*stun.URI{}, TURNs: []*stun.URI{}, networkSerial: 0, - statusRecorder: statusRecorder, - stateManager: stateManager, - checks: checks, + statusRecorder: services.StatusRecorder, + stateManager: services.StateManager, + checks: services.Checks, probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), jobExecutor: jobexec.NewExecutor(), + updateManager: services.UpdateManager, } log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) @@ -309,7 +315,7 @@ func (e *Engine) Stop() error { } if e.updateManager != nil { - e.updateManager.Stop() + e.updateManager.SetDownloadOnly() } log.Info("cleaning up status recorder states") @@ -559,13 +565,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return nil } -func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) { - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() - - e.handleAutoUpdateVersion(autoUpdateSettings, true) -} - func (e *Engine) createFirewall() error { if e.config.DisableFirewall { log.Infof("firewall is disabled") @@ -793,39 +792,22 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg return nil } -func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) { +func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) { + if e.updateManager == nil { + return + } + if autoUpdateSettings == nil { return } - disabled := autoUpdateSettings.Version == disableAutoUpdate - - // stop and cleanup if disabled - if e.updateManager != nil && disabled { - log.Infof("auto-update is disabled, stopping update manager") - e.updateManager.Stop() - e.updateManager = nil + if autoUpdateSettings.Version == disableAutoUpdate { + log.Infof("auto-update is disabled") + e.updateManager.SetDownloadOnly() return } - // Skip check unless AlwaysUpdate is enabled or this is the initial check at startup - if !autoUpdateSettings.AlwaysUpdate && !initialCheck { - log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check") - return - } - - // Start manager if needed - if e.updateManager == nil { - log.Infof("starting auto-update manager") - updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager) - if err != nil { - return - } - e.updateManager = updateManager - e.updateManager.Start(e.ctx) - } - log.Infof("handling auto-update version: %s", autoUpdateSettings.Version) - e.updateManager.SetVersion(autoUpdateSettings.Version) + e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate) } func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { @@ -842,7 +824,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { } if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil { - e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false) + e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate) } if update.GetNetbirdConfig() != nil { @@ -1315,8 +1297,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { // Test received (upstream) servers for availability right away instead of upon usage. // If no server of a server group responds this will disable the respective handler and retry later. - e.dnsServer.ProbeAvailability() - + go e.dnsServer.ProbeAvailability() return nil } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 012c8ad6e..f9e7f8fa0 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -251,9 +251,6 @@ func TestEngine_SSH(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine( ctx, cancel, - &signal.MockClient{}, - &mgmt.MockClient{}, - relayMgr, &EngineConfig{ WgIfaceName: "utun101", WgAddr: "100.64.0.1/24", @@ -263,10 +260,13 @@ func TestEngine_SSH(t *testing.T) { MTU: iface.DefaultMTU, SSHKey: sshKey, }, + EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}, - peer.NewRecorder("https://mgm"), - nil, - nil, ) engine.dnsServer = &dns.MockServer{ @@ -428,13 +428,18 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { defer cancel() relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: "utun102", WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) wgIface := &MockWGIface{ NameFunc: func() string { return "utun102" }, @@ -647,13 +652,18 @@ func TestEngine_Sync(t *testing.T) { return nil } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: "utun103", WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{SyncFunc: syncFunc}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) engine.ctx = ctx engine.dnsServer = &dns.MockServer{ @@ -812,13 +822,18 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { wgAddr := fmt.Sprintf("100.66.%d.1/24", n) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) engine.ctx = ctx newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { @@ -1014,13 +1029,18 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { wgAddr := fmt.Sprintf("100.66.%d.1/24", n) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) engine.ctx = ctx newNet, err := stdnet.NewNet(context.Background(), nil) @@ -1546,7 +1566,12 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil + e, err := NewEngine(ctx, cancel, conf, EngineServices{ + SignalClient: signalClient, + MgmClient: mgmtClient, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}), nil e.ctx = ctx return e, err } diff --git a/client/internal/expose/manager.go b/client/internal/expose/manager.go index 8cd93685e..c59a1a7bd 100644 --- a/client/internal/expose/manager.go +++ b/client/internal/expose/manager.go @@ -12,9 +12,10 @@ const renewTimeout = 10 * time.Second // Response holds the response from exposing a service. type Response struct { - ServiceName string - ServiceURL string - Domain string + ServiceName string + ServiceURL string + Domain string + PortAutoAssigned bool } type Request struct { @@ -25,6 +26,7 @@ type Request struct { Pin string Password string UserGroups []string + ListenPort uint16 } type ManagementClient interface { diff --git a/client/internal/expose/request.go b/client/internal/expose/request.go index 7e12d0513..bff4f2ce7 100644 --- a/client/internal/expose/request.go +++ b/client/internal/expose/request.go @@ -15,6 +15,7 @@ func NewRequest(req *daemonProto.ExposeServiceRequest) *Request { UserGroups: req.UserGroups, Domain: req.Domain, NamePrefix: req.NamePrefix, + ListenPort: uint16(req.ListenPort), } } @@ -27,13 +28,15 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest { Pin: req.Pin, Password: req.Password, UserGroups: req.UserGroups, + ListenPort: req.ListenPort, } } func fromClientExposeResponse(response *mgm.ExposeResponse) *Response { return &Response{ - ServiceName: response.ServiceName, - Domain: response.Domain, - ServiceURL: response.ServiceURL, + ServiceName: response.ServiceName, + Domain: response.Domain, + ServiceURL: response.ServiceURL, + PortAutoAssigned: response.PortAutoAssigned, } } diff --git a/client/internal/updatemanager/manager_test.go b/client/internal/updatemanager/manager_test.go deleted file mode 100644 index 20ddec10d..000000000 --- a/client/internal/updatemanager/manager_test.go +++ /dev/null @@ -1,214 +0,0 @@ -//go:build windows || darwin - -package updatemanager - -import ( - "context" - "fmt" - "path" - "testing" - "time" - - v "github.com/hashicorp/go-version" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/statemanager" -) - -type versionUpdateMock struct { - latestVersion *v.Version - onUpdate func() -} - -func (v versionUpdateMock) StopWatch() {} - -func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool { - return false -} - -func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) { - v.onUpdate = updateFn -} - -func (v versionUpdateMock) LatestVersion() *v.Version { - return v.latestVersion -} - -func (v versionUpdateMock) StartFetcher() {} - -func Test_LatestVersion(t *testing.T) { - testMatrix := []struct { - name string - daemonVersion string - initialLatestVersion *v.Version - latestVersion *v.Version - shouldUpdateInit bool - shouldUpdateLater bool - }{ - { - name: "Should only trigger update once due to time between triggers being < 5 Minutes", - daemonVersion: "1.0.0", - initialLatestVersion: v.Must(v.NewSemver("1.0.1")), - latestVersion: v.Must(v.NewSemver("1.0.2")), - shouldUpdateInit: true, - shouldUpdateLater: false, - }, - { - name: "Shouldn't update initially, but should update as soon as latest version is fetched", - daemonVersion: "1.0.0", - initialLatestVersion: nil, - latestVersion: v.Must(v.NewSemver("1.0.1")), - shouldUpdateInit: false, - shouldUpdateLater: true, - }, - } - - for idx, c := range testMatrix { - mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} - tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) - m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) - m.update = mockUpdate - - targetVersionChan := make(chan string, 1) - - m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { - targetVersionChan <- targetVersion - return nil - } - m.currentVersion = c.daemonVersion - m.Start(context.Background()) - m.SetVersion("latest") - var triggeredInit bool - select { - case targetVersion := <-targetVersionChan: - if targetVersion != c.initialLatestVersion.String() { - t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion) - } - triggeredInit = true - case <-time.After(10 * time.Millisecond): - triggeredInit = false - } - if triggeredInit != c.shouldUpdateInit { - t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) - } - - mockUpdate.latestVersion = c.latestVersion - mockUpdate.onUpdate() - - var triggeredLater bool - select { - case targetVersion := <-targetVersionChan: - if targetVersion != c.latestVersion.String() { - t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) - } - triggeredLater = true - case <-time.After(10 * time.Millisecond): - triggeredLater = false - } - if triggeredLater != c.shouldUpdateLater { - t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) - } - - m.Stop() - } -} - -func Test_HandleUpdate(t *testing.T) { - testMatrix := []struct { - name string - daemonVersion string - latestVersion *v.Version - expectedVersion string - shouldUpdate bool - }{ - { - name: "Update to a specific version should update regardless of if latestVersion is available yet", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "0.56.0", - shouldUpdate: true, - }, - { - name: "Update to specific version should not update if version matches", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "0.55.0", - shouldUpdate: false, - }, - { - name: "Update to specific version should not update if current version is newer", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "0.54.0", - shouldUpdate: false, - }, - { - name: "Update to latest version should update if latest is newer", - daemonVersion: "0.55.0", - latestVersion: v.Must(v.NewSemver("0.56.0")), - expectedVersion: "latest", - shouldUpdate: true, - }, - { - name: "Update to latest version should not update if latest == current", - daemonVersion: "0.56.0", - latestVersion: v.Must(v.NewSemver("0.56.0")), - expectedVersion: "latest", - shouldUpdate: false, - }, - { - name: "Should not update if daemon version is invalid", - daemonVersion: "development", - latestVersion: v.Must(v.NewSemver("1.0.0")), - expectedVersion: "latest", - shouldUpdate: false, - }, - { - name: "Should not update if expecting latest and latest version is unavailable", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "latest", - shouldUpdate: false, - }, - { - name: "Should not update if expected version is invalid", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "development", - shouldUpdate: false, - }, - } - for idx, c := range testMatrix { - tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) - m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) - m.update = &versionUpdateMock{latestVersion: c.latestVersion} - targetVersionChan := make(chan string, 1) - - m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { - targetVersionChan <- targetVersion - return nil - } - - m.currentVersion = c.daemonVersion - m.Start(context.Background()) - m.SetVersion(c.expectedVersion) - - var updateTriggered bool - select { - case targetVersion := <-targetVersionChan: - if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() { - t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) - } else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion { - t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion) - } - updateTriggered = true - case <-time.After(10 * time.Millisecond): - updateTriggered = false - } - - if updateTriggered != c.shouldUpdate { - t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered) - } - m.Stop() - } -} diff --git a/client/internal/updatemanager/manager_unsupported.go b/client/internal/updatemanager/manager_unsupported.go deleted file mode 100644 index 4e87c2d77..000000000 --- a/client/internal/updatemanager/manager_unsupported.go +++ /dev/null @@ -1,39 +0,0 @@ -//go:build !windows && !darwin - -package updatemanager - -import ( - "context" - "fmt" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/statemanager" -) - -// Manager is a no-op stub for unsupported platforms -type Manager struct{} - -// NewManager returns a no-op manager for unsupported platforms -func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { - return nil, fmt.Errorf("update manager is not supported on this platform") -} - -// CheckUpdateSuccess is a no-op on unsupported platforms -func (m *Manager) CheckUpdateSuccess(ctx context.Context) { - // no-op -} - -// Start is a no-op on unsupported platforms -func (m *Manager) Start(ctx context.Context) { - // no-op -} - -// SetVersion is a no-op on unsupported platforms -func (m *Manager) SetVersion(expectedVersion string) { - // no-op -} - -// Stop is a no-op on unsupported platforms -func (m *Manager) Stop() { - // no-op -} diff --git a/client/internal/updatemanager/doc.go b/client/internal/updater/doc.go similarity index 93% rename from client/internal/updatemanager/doc.go rename to client/internal/updater/doc.go index 54d1bdeab..e1924aa43 100644 --- a/client/internal/updatemanager/doc.go +++ b/client/internal/updater/doc.go @@ -1,4 +1,4 @@ -// Package updatemanager provides automatic update management for the NetBird client. +// Package updater provides automatic update management for the NetBird client. // It monitors for new versions, handles update triggers from management server directives, // and orchestrates the download and installation of client updates. // @@ -32,4 +32,4 @@ // // This enables verification of successful updates and appropriate user notification // after the client restarts with the new version. -package updatemanager +package updater diff --git a/client/internal/updatemanager/downloader/downloader.go b/client/internal/updater/downloader/downloader.go similarity index 100% rename from client/internal/updatemanager/downloader/downloader.go rename to client/internal/updater/downloader/downloader.go diff --git a/client/internal/updatemanager/downloader/downloader_test.go b/client/internal/updater/downloader/downloader_test.go similarity index 100% rename from client/internal/updatemanager/downloader/downloader_test.go rename to client/internal/updater/downloader/downloader_test.go diff --git a/client/internal/updatemanager/installer/binary_nowindows.go b/client/internal/updater/installer/binary_nowindows.go similarity index 100% rename from client/internal/updatemanager/installer/binary_nowindows.go rename to client/internal/updater/installer/binary_nowindows.go diff --git a/client/internal/updatemanager/installer/binary_windows.go b/client/internal/updater/installer/binary_windows.go similarity index 100% rename from client/internal/updatemanager/installer/binary_windows.go rename to client/internal/updater/installer/binary_windows.go diff --git a/client/internal/updatemanager/installer/doc.go b/client/internal/updater/installer/doc.go similarity index 100% rename from client/internal/updatemanager/installer/doc.go rename to client/internal/updater/installer/doc.go diff --git a/client/internal/updatemanager/installer/installer.go b/client/internal/updater/installer/installer.go similarity index 100% rename from client/internal/updatemanager/installer/installer.go rename to client/internal/updater/installer/installer.go diff --git a/client/internal/updatemanager/installer/installer_common.go b/client/internal/updater/installer/installer_common.go similarity index 97% rename from client/internal/updatemanager/installer/installer_common.go rename to client/internal/updater/installer/installer_common.go index 03378d55f..8e44bee82 100644 --- a/client/internal/updatemanager/installer/installer_common.go +++ b/client/internal/updater/installer/installer_common.go @@ -16,8 +16,8 @@ import ( goversion "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/downloader" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) type Installer struct { diff --git a/client/internal/updatemanager/installer/installer_log_darwin.go b/client/internal/updater/installer/installer_log_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/installer_log_darwin.go rename to client/internal/updater/installer/installer_log_darwin.go diff --git a/client/internal/updatemanager/installer/installer_log_windows.go b/client/internal/updater/installer/installer_log_windows.go similarity index 100% rename from client/internal/updatemanager/installer/installer_log_windows.go rename to client/internal/updater/installer/installer_log_windows.go diff --git a/client/internal/updatemanager/installer/installer_run_darwin.go b/client/internal/updater/installer/installer_run_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/installer_run_darwin.go rename to client/internal/updater/installer/installer_run_darwin.go diff --git a/client/internal/updatemanager/installer/installer_run_windows.go b/client/internal/updater/installer/installer_run_windows.go similarity index 100% rename from client/internal/updatemanager/installer/installer_run_windows.go rename to client/internal/updater/installer/installer_run_windows.go diff --git a/client/internal/updatemanager/installer/log.go b/client/internal/updater/installer/log.go similarity index 100% rename from client/internal/updatemanager/installer/log.go rename to client/internal/updater/installer/log.go diff --git a/client/internal/updatemanager/installer/procattr_darwin.go b/client/internal/updater/installer/procattr_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/procattr_darwin.go rename to client/internal/updater/installer/procattr_darwin.go diff --git a/client/internal/updatemanager/installer/procattr_windows.go b/client/internal/updater/installer/procattr_windows.go similarity index 100% rename from client/internal/updatemanager/installer/procattr_windows.go rename to client/internal/updater/installer/procattr_windows.go diff --git a/client/internal/updatemanager/installer/repourl_dev.go b/client/internal/updater/installer/repourl_dev.go similarity index 100% rename from client/internal/updatemanager/installer/repourl_dev.go rename to client/internal/updater/installer/repourl_dev.go diff --git a/client/internal/updatemanager/installer/repourl_prod.go b/client/internal/updater/installer/repourl_prod.go similarity index 100% rename from client/internal/updatemanager/installer/repourl_prod.go rename to client/internal/updater/installer/repourl_prod.go diff --git a/client/internal/updatemanager/installer/result.go b/client/internal/updater/installer/result.go similarity index 98% rename from client/internal/updatemanager/installer/result.go rename to client/internal/updater/installer/result.go index 03d08d527..526c3eb53 100644 --- a/client/internal/updatemanager/installer/result.go +++ b/client/internal/updater/installer/result.go @@ -203,7 +203,10 @@ func (rh *ResultHandler) write(result Result) error { func (rh *ResultHandler) cleanup() error { err := os.Remove(rh.resultFile) - if err != nil && !os.IsNotExist(err) { + if err != nil { + if os.IsNotExist(err) { + return nil + } return err } log.Debugf("delete installer result file: %s", rh.resultFile) diff --git a/client/internal/updatemanager/installer/types.go b/client/internal/updater/installer/types.go similarity index 100% rename from client/internal/updatemanager/installer/types.go rename to client/internal/updater/installer/types.go diff --git a/client/internal/updatemanager/installer/types_darwin.go b/client/internal/updater/installer/types_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/types_darwin.go rename to client/internal/updater/installer/types_darwin.go diff --git a/client/internal/updatemanager/installer/types_windows.go b/client/internal/updater/installer/types_windows.go similarity index 100% rename from client/internal/updatemanager/installer/types_windows.go rename to client/internal/updater/installer/types_windows.go diff --git a/client/internal/updatemanager/manager.go b/client/internal/updater/manager.go similarity index 52% rename from client/internal/updatemanager/manager.go rename to client/internal/updater/manager.go index eae11de56..dfcb93177 100644 --- a/client/internal/updatemanager/manager.go +++ b/client/internal/updater/manager.go @@ -1,12 +1,9 @@ -//go:build windows || darwin - -package updatemanager +package updater import ( "context" "errors" "fmt" - "runtime" "sync" "time" @@ -15,7 +12,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" ) @@ -41,6 +38,9 @@ type Manager struct { statusRecorder *peer.Status stateManager *statemanager.Manager + downloadOnly bool // true when no enforcement from management; notifies UI to download latest + forceUpdate bool // true when management sets AlwaysUpdate; skips UI interaction and installs directly + lastTrigger time.Time mgmUpdateChan chan struct{} updateChannel chan struct{} @@ -53,37 +53,38 @@ type Manager struct { expectedVersion *v.Version updateToLatestVersion bool - // updateMutex protect update and expectedVersion fields + pendingVersion *v.Version + + // updateMutex protects update, expectedVersion, updateToLatestVersion, + // downloadOnly, forceUpdate, pendingVersion, and lastTrigger fields updateMutex sync.Mutex - triggerUpdateFn func(context.Context, string) error + // installMutex and installing guard against concurrent installation attempts + installMutex sync.Mutex + installing bool + + // protect to start the service multiple times + mu sync.Mutex + + autoUpdateSupported func() bool } -func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { - if runtime.GOOS == "darwin" { - isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable() - if isBrew { - log.Warnf("auto-update disabled on Home Brew installation") - return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet") - } - } - return newManager(statusRecorder, stateManager) -} - -func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { +// NewManager creates a new update manager. The manager is single-use: once Stop() is called, it cannot be restarted. +func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *Manager { manager := &Manager{ - statusRecorder: statusRecorder, - stateManager: stateManager, - mgmUpdateChan: make(chan struct{}, 1), - updateChannel: make(chan struct{}, 1), - currentVersion: version.NetbirdVersion(), - update: version.NewUpdate("nb/client"), + statusRecorder: statusRecorder, + stateManager: stateManager, + mgmUpdateChan: make(chan struct{}, 1), + updateChannel: make(chan struct{}, 1), + currentVersion: version.NetbirdVersion(), + update: version.NewUpdate("nb/client"), + downloadOnly: true, + autoUpdateSupported: isAutoUpdateSupported, } - manager.triggerUpdateFn = manager.triggerUpdate stateManager.RegisterState(&UpdateState{}) - return manager, nil + return manager } // CheckUpdateSuccess checks if the update was successful and send a notification. @@ -124,8 +125,10 @@ func (m *Manager) CheckUpdateSuccess(ctx context.Context) { } func (m *Manager) Start(ctx context.Context) { + log.Infof("starting update manager") + m.mu.Lock() + defer m.mu.Unlock() if m.cancel != nil { - log.Errorf("Manager already started") return } @@ -142,13 +145,32 @@ func (m *Manager) Start(ctx context.Context) { m.cancel = cancel m.wg.Add(1) - go m.updateLoop(ctx) + go func() { + defer m.wg.Done() + m.updateLoop(ctx) + }() } -func (m *Manager) SetVersion(expectedVersion string) { - log.Infof("set expected agent version for upgrade: %s", expectedVersion) - if m.cancel == nil { - log.Errorf("manager not started") +func (m *Manager) SetDownloadOnly() { + m.updateMutex.Lock() + m.downloadOnly = true + m.forceUpdate = false + m.expectedVersion = nil + m.updateToLatestVersion = false + m.lastTrigger = time.Time{} + m.updateMutex.Unlock() + + select { + case m.mgmUpdateChan <- struct{}{}: + default: + } +} + +func (m *Manager) SetVersion(expectedVersion string, forceUpdate bool) { + log.Infof("expected version changed to %s, force update: %t", expectedVersion, forceUpdate) + + if !m.autoUpdateSupported() { + log.Warnf("auto-update not supported on this platform") return } @@ -159,6 +181,7 @@ func (m *Manager) SetVersion(expectedVersion string) { log.Errorf("empty expected version provided") m.expectedVersion = nil m.updateToLatestVersion = false + m.downloadOnly = true return } @@ -178,12 +201,97 @@ func (m *Manager) SetVersion(expectedVersion string) { m.updateToLatestVersion = false } + m.lastTrigger = time.Time{} + m.downloadOnly = false + m.forceUpdate = forceUpdate + select { case m.mgmUpdateChan <- struct{}{}: default: } } +// Install triggers the installation of the pending version. It is called when the user clicks the install button in the UI. +func (m *Manager) Install(ctx context.Context) error { + if !m.autoUpdateSupported() { + return fmt.Errorf("auto-update not supported on this platform") + } + + m.updateMutex.Lock() + pending := m.pendingVersion + m.updateMutex.Unlock() + + if pending == nil { + return fmt.Errorf("no pending version to install") + } + + return m.tryInstall(ctx, pending) +} + +// tryInstall ensures only one installation runs at a time. Concurrent callers +// receive an error immediately rather than queuing behind a running install. +func (m *Manager) tryInstall(ctx context.Context, targetVersion *v.Version) error { + m.installMutex.Lock() + if m.installing { + m.installMutex.Unlock() + return fmt.Errorf("installation already in progress") + } + m.installing = true + m.installMutex.Unlock() + + defer func() { + m.installMutex.Lock() + m.installing = false + m.installMutex.Unlock() + }() + + return m.install(ctx, targetVersion) +} + +// NotifyUI re-publishes the current update state to a newly connected UI client. +// Only needed for download-only mode where the latest version is already cached +// NotifyUI re-publishes the current update state so a newly connected UI gets the info. +func (m *Manager) NotifyUI() { + m.updateMutex.Lock() + if m.update == nil { + m.updateMutex.Unlock() + return + } + downloadOnly := m.downloadOnly + pendingVersion := m.pendingVersion + latestVersion := m.update.LatestVersion() + m.updateMutex.Unlock() + + if downloadOnly { + if latestVersion == nil { + return + } + currentVersion, err := v.NewVersion(m.currentVersion) + if err != nil || currentVersion.GreaterThanOrEqual(latestVersion) { + return + } + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": latestVersion.String()}, + ) + return + } + + if pendingVersion != nil { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": pendingVersion.String(), "enforced": "true"}, + ) + } +} + +// Stop is not used at the moment because it fully depends on the daemon. In a future refactor it may make sense to use it. func (m *Manager) Stop() { if m.cancel == nil { return @@ -214,8 +322,6 @@ func (m *Manager) onContextCancel() { } func (m *Manager) updateLoop(ctx context.Context) { - defer m.wg.Done() - for { select { case <-ctx.Done(): @@ -239,55 +345,89 @@ func (m *Manager) handleUpdate(ctx context.Context) { return } - expectedVersion := m.expectedVersion - useLatest := m.updateToLatestVersion + downloadOnly := m.downloadOnly + forceUpdate := m.forceUpdate curLatestVersion := m.update.LatestVersion() - m.updateMutex.Unlock() switch { - // Resolve "latest" to actual version - case useLatest: + // Download-only mode or resolve "latest" to actual version + case downloadOnly, m.updateToLatestVersion: if curLatestVersion == nil { log.Tracef("latest version not fetched yet") + m.updateMutex.Unlock() return } updateVersion = curLatestVersion - // Update to specific version - case expectedVersion != nil: - updateVersion = expectedVersion + // Install to specific version + case m.expectedVersion != nil: + updateVersion = m.expectedVersion default: log.Debugf("no expected version information set") + m.updateMutex.Unlock() return } log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion) - if !m.shouldUpdate(updateVersion) { + if !m.shouldUpdate(updateVersion, forceUpdate) { + m.updateMutex.Unlock() return } m.lastTrigger = time.Now() - log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion) - m.statusRecorder.PublishEvent( - cProto.SystemEvent_CRITICAL, - cProto.SystemEvent_SYSTEM, - "Automatically updating client", - "Your client version is older than auto-update version set in Management, updating client now.", - nil, - ) + log.Infof("new version available: %s", updateVersion) + + if !downloadOnly && !forceUpdate { + m.pendingVersion = updateVersion + } + m.updateMutex.Unlock() + + if downloadOnly { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": updateVersion.String()}, + ) + return + } + + if forceUpdate { + if err := m.tryInstall(ctx, updateVersion); err != nil { + log.Errorf("force update failed: %v", err) + } + return + } + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": updateVersion.String(), "enforced": "true"}, + ) +} + +func (m *Manager) install(ctx context.Context, pendingVersion *v.Version) error { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_CRITICAL, + cProto.SystemEvent_SYSTEM, + "Updating client", + "Installing update now.", + nil, + ) m.statusRecorder.PublishEvent( cProto.SystemEvent_CRITICAL, cProto.SystemEvent_SYSTEM, "", "", - map[string]string{"progress_window": "show", "version": updateVersion.String()}, + map[string]string{"progress_window": "show", "version": pendingVersion.String()}, ) updateState := UpdateState{ PreUpdateVersion: m.currentVersion, - TargetVersion: updateVersion.String(), + TargetVersion: pendingVersion.String(), } - if err := m.stateManager.UpdateState(updateState); err != nil { log.Warnf("failed to update state: %v", err) } else { @@ -296,8 +436,9 @@ func (m *Manager) handleUpdate(ctx context.Context) { } } - if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil { - log.Errorf("Error triggering auto-update: %v", err) + inst := installer.New() + if err := inst.RunInstallation(ctx, pendingVersion.String()); err != nil { + log.Errorf("error triggering update: %v", err) m.statusRecorder.PublishEvent( cProto.SystemEvent_ERROR, cProto.SystemEvent_SYSTEM, @@ -305,7 +446,9 @@ func (m *Manager) handleUpdate(ctx context.Context) { fmt.Sprintf("Auto-update failed: %v", err), nil, ) + return err } + return nil } // loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it. @@ -339,7 +482,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e return updateState, nil } -func (m *Manager) shouldUpdate(updateVersion *v.Version) bool { +func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool { if m.currentVersion == developmentVersion { log.Debugf("skipping auto-update, running development version") return false @@ -354,8 +497,8 @@ func (m *Manager) shouldUpdate(updateVersion *v.Version) bool { return false } - if time.Since(m.lastTrigger) < 5*time.Minute { - log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger)) + if forceUpdate && time.Since(m.lastTrigger) < 3*time.Minute { + log.Infof("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger)) return false } @@ -367,8 +510,3 @@ func (m *Manager) lastResultErrReason() string { result := installer.NewResultHandler(inst.TempDir()) return result.GetErrorResultReason() } - -func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error { - inst := installer.New() - return inst.RunInstallation(ctx, targetVersion) -} diff --git a/client/internal/updater/manager_linux_test.go b/client/internal/updater/manager_linux_test.go new file mode 100644 index 000000000..b05dd7e7d --- /dev/null +++ b/client/internal/updater/manager_linux_test.go @@ -0,0 +1,111 @@ +//go:build !windows && !darwin + +package updater + +import ( + "context" + "fmt" + "path" + "testing" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// On Linux, only Mode 1 (downloadOnly) is supported. +// SetVersion is a no-op because auto-update installation is not supported. + +func Test_LatestVersion_Linux(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + initialLatestVersion *v.Version + latestVersion *v.Version + shouldUpdateInit bool + shouldUpdateLater bool + }{ + { + name: "Should notify again when a newer version arrives even within 5 minutes", + daemonVersion: "1.0.0", + initialLatestVersion: v.Must(v.NewSemver("1.0.1")), + latestVersion: v.Must(v.NewSemver("1.0.2")), + shouldUpdateInit: true, + shouldUpdateLater: true, + }, + { + name: "Shouldn't notify initially, but should notify as soon as latest version is fetched", + daemonVersion: "1.0.0", + initialLatestVersion: nil, + latestVersion: v.Must(v.NewSemver("1.0.1")), + shouldUpdateInit: false, + shouldUpdateLater: true, + }, + } + + for idx, c := range testMatrix { + mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = mockUpdate + m.currentVersion = c.daemonVersion + m.Start(context.Background()) + m.SetDownloadOnly() + + ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredInit := ver != "" + if enforced { + t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name) + } + if triggeredInit != c.shouldUpdateInit { + t.Errorf("%s: Initial notify mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) + } + if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() { + t.Errorf("%s: Initial version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver) + } + + mockUpdate.latestVersion = c.latestVersion + mockUpdate.onUpdate() + + ver, enforced = waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredLater := ver != "" + if enforced { + t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name) + } + if triggeredLater != c.shouldUpdateLater { + t.Errorf("%s: Later notify mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) + } + if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() { + t.Errorf("%s: Later version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver) + } + + m.Stop() + } +} + +func Test_SetVersion_NoOp_Linux(t *testing.T) { + // On Linux, SetVersion should be a no-op — no events fired + tmpFile := path.Join(t.TempDir(), "update-test-noop.json") + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))} + m.currentVersion = "1.0.0" + m.Start(context.Background()) + m.SetVersion("1.0.1", false) + + ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond) + if ver != "" { + t.Errorf("SetVersion should be a no-op on Linux, but got event with version %s", ver) + } + + m.Stop() +} diff --git a/client/internal/updater/manager_test.go b/client/internal/updater/manager_test.go new file mode 100644 index 000000000..107dca2b3 --- /dev/null +++ b/client/internal/updater/manager_test.go @@ -0,0 +1,227 @@ +//go:build windows || darwin + +package updater + +import ( + "context" + "fmt" + "path" + "testing" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" + cProto "github.com/netbirdio/netbird/client/proto" +) + +func Test_LatestVersion(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + initialLatestVersion *v.Version + latestVersion *v.Version + shouldUpdateInit bool + shouldUpdateLater bool + }{ + { + name: "Should notify again when a newer version arrives even within 5 minutes", + daemonVersion: "1.0.0", + initialLatestVersion: v.Must(v.NewSemver("1.0.1")), + latestVersion: v.Must(v.NewSemver("1.0.2")), + shouldUpdateInit: true, + shouldUpdateLater: true, + }, + { + name: "Shouldn't update initially, but should update as soon as latest version is fetched", + daemonVersion: "1.0.0", + initialLatestVersion: nil, + latestVersion: v.Must(v.NewSemver("1.0.1")), + shouldUpdateInit: false, + shouldUpdateLater: true, + }, + } + + for idx, c := range testMatrix { + mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = mockUpdate + m.currentVersion = c.daemonVersion + m.autoUpdateSupported = func() bool { return true } + m.Start(context.Background()) + m.SetVersion("latest", false) + + ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredInit := ver != "" + if triggeredInit != c.shouldUpdateInit { + t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) + } + if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() { + t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver) + } + + mockUpdate.latestVersion = c.latestVersion + mockUpdate.onUpdate() + + ver, _ = waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredLater := ver != "" + if triggeredLater != c.shouldUpdateLater { + t.Errorf("%s: Later update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) + } + if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() { + t.Errorf("%s: Later update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver) + } + + m.Stop() + } +} + +func Test_HandleUpdate(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + latestVersion *v.Version + expectedVersion string + shouldUpdate bool + }{ + { + name: "Install to a specific version should update regardless of if latestVersion is available yet", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.56.0", + shouldUpdate: true, + }, + { + name: "Install to specific version should not update if version matches", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.55.0", + shouldUpdate: false, + }, + { + name: "Install to specific version should not update if current version is newer", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.54.0", + shouldUpdate: false, + }, + { + name: "Install to latest version should update if latest is newer", + daemonVersion: "0.55.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: true, + }, + { + name: "Install to latest version should not update if latest == current", + daemonVersion: "0.56.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if daemon version is invalid", + daemonVersion: "development", + latestVersion: v.Must(v.NewSemver("1.0.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expecting latest and latest version is unavailable", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expected version is invalid", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "development", + shouldUpdate: false, + }, + } + for idx, c := range testMatrix { + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: c.latestVersion} + m.currentVersion = c.daemonVersion + m.autoUpdateSupported = func() bool { return true } + m.Start(context.Background()) + m.SetVersion(c.expectedVersion, false) + + ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond) + updateTriggered := ver != "" + + if updateTriggered { + if c.expectedVersion == "latest" && c.latestVersion != nil && ver != c.latestVersion.String() { + t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver) + } else if c.expectedVersion != "latest" && c.expectedVersion != "development" && ver != c.expectedVersion { + t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.expectedVersion, ver) + } + } + + if updateTriggered != c.shouldUpdate { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered) + } + m.Stop() + } +} + +func Test_EnforcedMetadata(t *testing.T) { + // Mode 1 (downloadOnly): no enforced metadata + tmpFile := path.Join(t.TempDir(), "update-test-mode1.json") + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))} + m.currentVersion = "1.0.0" + m.Start(context.Background()) + m.SetDownloadOnly() + + ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond) + if ver == "" { + t.Fatal("Mode 1: expected new_version_available event") + } + if enforced { + t.Error("Mode 1: expected no enforced metadata") + } + m.Stop() + + // Mode 2 (enforced, forceUpdate=false): enforced metadata present, no auto-install + tmpFile2 := path.Join(t.TempDir(), "update-test-mode2.json") + recorder2 := peer.NewRecorder("") + sub2 := recorder2.SubscribeToEvents() + defer recorder2.UnsubscribeFromEvents(sub2) + + m2 := NewManager(recorder2, statemanager.New(tmpFile2)) + m2.update = &versionUpdateMock{latestVersion: nil} + m2.currentVersion = "1.0.0" + m2.autoUpdateSupported = func() bool { return true } + m2.Start(context.Background()) + m2.SetVersion("1.0.1", false) + + ver, enforced2 := waitForUpdateEvent(sub2, 500*time.Millisecond) + if ver == "" { + t.Fatal("Mode 2: expected new_version_available event") + } + if !enforced2 { + t.Error("Mode 2: expected enforced metadata") + } + m2.Stop() +} + +// ensure the proto import is used +var _ = cProto.SystemEvent_INFO diff --git a/client/internal/updater/manager_test_helpers_test.go b/client/internal/updater/manager_test_helpers_test.go new file mode 100644 index 000000000..c7faee1f4 --- /dev/null +++ b/client/internal/updater/manager_test_helpers_test.go @@ -0,0 +1,56 @@ +package updater + +import ( + "strconv" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" +) + +type versionUpdateMock struct { + latestVersion *v.Version + onUpdate func() +} + +func (m versionUpdateMock) StopWatch() {} + +func (m versionUpdateMock) SetDaemonVersion(newVersion string) bool { + return false +} + +func (m *versionUpdateMock) SetOnUpdateListener(updateFn func()) { + m.onUpdate = updateFn +} + +func (m versionUpdateMock) LatestVersion() *v.Version { + return m.latestVersion +} + +func (m versionUpdateMock) StartFetcher() {} + +// waitForUpdateEvent waits for a new_version_available event, returns the version string or "" on timeout. +func waitForUpdateEvent(sub *peer.EventSubscription, timeout time.Duration) (version string, enforced bool) { + timer := time.NewTimer(timeout) + defer timer.Stop() + for { + select { + case event, ok := <-sub.Events(): + if !ok { + return "", false + } + if val, ok := event.Metadata["new_version_available"]; ok { + enforced := false + if raw, ok := event.Metadata["enforced"]; ok { + if parsed, err := strconv.ParseBool(raw); err == nil { + enforced = parsed + } + } + return val, enforced + } + case <-timer.C: + return "", false + } + } +} diff --git a/client/internal/updatemanager/reposign/artifact.go b/client/internal/updater/reposign/artifact.go similarity index 100% rename from client/internal/updatemanager/reposign/artifact.go rename to client/internal/updater/reposign/artifact.go diff --git a/client/internal/updatemanager/reposign/artifact_test.go b/client/internal/updater/reposign/artifact_test.go similarity index 100% rename from client/internal/updatemanager/reposign/artifact_test.go rename to client/internal/updater/reposign/artifact_test.go diff --git a/client/internal/updatemanager/reposign/certs/root-pub.pem b/client/internal/updater/reposign/certs/root-pub.pem similarity index 100% rename from client/internal/updatemanager/reposign/certs/root-pub.pem rename to client/internal/updater/reposign/certs/root-pub.pem diff --git a/client/internal/updatemanager/reposign/certsdev/root-pub.pem b/client/internal/updater/reposign/certsdev/root-pub.pem similarity index 100% rename from client/internal/updatemanager/reposign/certsdev/root-pub.pem rename to client/internal/updater/reposign/certsdev/root-pub.pem diff --git a/client/internal/updatemanager/reposign/doc.go b/client/internal/updater/reposign/doc.go similarity index 100% rename from client/internal/updatemanager/reposign/doc.go rename to client/internal/updater/reposign/doc.go diff --git a/client/internal/updatemanager/reposign/embed_dev.go b/client/internal/updater/reposign/embed_dev.go similarity index 100% rename from client/internal/updatemanager/reposign/embed_dev.go rename to client/internal/updater/reposign/embed_dev.go diff --git a/client/internal/updatemanager/reposign/embed_prod.go b/client/internal/updater/reposign/embed_prod.go similarity index 100% rename from client/internal/updatemanager/reposign/embed_prod.go rename to client/internal/updater/reposign/embed_prod.go diff --git a/client/internal/updatemanager/reposign/key.go b/client/internal/updater/reposign/key.go similarity index 100% rename from client/internal/updatemanager/reposign/key.go rename to client/internal/updater/reposign/key.go diff --git a/client/internal/updatemanager/reposign/key_test.go b/client/internal/updater/reposign/key_test.go similarity index 100% rename from client/internal/updatemanager/reposign/key_test.go rename to client/internal/updater/reposign/key_test.go diff --git a/client/internal/updatemanager/reposign/revocation.go b/client/internal/updater/reposign/revocation.go similarity index 100% rename from client/internal/updatemanager/reposign/revocation.go rename to client/internal/updater/reposign/revocation.go diff --git a/client/internal/updatemanager/reposign/revocation_test.go b/client/internal/updater/reposign/revocation_test.go similarity index 100% rename from client/internal/updatemanager/reposign/revocation_test.go rename to client/internal/updater/reposign/revocation_test.go diff --git a/client/internal/updatemanager/reposign/root.go b/client/internal/updater/reposign/root.go similarity index 100% rename from client/internal/updatemanager/reposign/root.go rename to client/internal/updater/reposign/root.go diff --git a/client/internal/updatemanager/reposign/root_test.go b/client/internal/updater/reposign/root_test.go similarity index 100% rename from client/internal/updatemanager/reposign/root_test.go rename to client/internal/updater/reposign/root_test.go diff --git a/client/internal/updatemanager/reposign/signature.go b/client/internal/updater/reposign/signature.go similarity index 100% rename from client/internal/updatemanager/reposign/signature.go rename to client/internal/updater/reposign/signature.go diff --git a/client/internal/updatemanager/reposign/signature_test.go b/client/internal/updater/reposign/signature_test.go similarity index 100% rename from client/internal/updatemanager/reposign/signature_test.go rename to client/internal/updater/reposign/signature_test.go diff --git a/client/internal/updatemanager/reposign/verify.go b/client/internal/updater/reposign/verify.go similarity index 98% rename from client/internal/updatemanager/reposign/verify.go rename to client/internal/updater/reposign/verify.go index 0af2a8c9e..f64b26a30 100644 --- a/client/internal/updatemanager/reposign/verify.go +++ b/client/internal/updater/reposign/verify.go @@ -10,7 +10,7 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" + "github.com/netbirdio/netbird/client/internal/updater/downloader" ) const ( diff --git a/client/internal/updatemanager/reposign/verify_test.go b/client/internal/updater/reposign/verify_test.go similarity index 100% rename from client/internal/updatemanager/reposign/verify_test.go rename to client/internal/updater/reposign/verify_test.go diff --git a/client/internal/updater/supported_darwin.go b/client/internal/updater/supported_darwin.go new file mode 100644 index 000000000..b27754366 --- /dev/null +++ b/client/internal/updater/supported_darwin.go @@ -0,0 +1,22 @@ +package updater + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updater/installer" +) + +func isAutoUpdateSupported() bool { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + isBrew := !installer.TypeOfInstaller(ctx).Downloadable() + if isBrew { + log.Warnf("auto-update disabled on Homebrew installation") + return false + } + return true +} diff --git a/client/internal/updater/supported_other.go b/client/internal/updater/supported_other.go new file mode 100644 index 000000000..e09e8c3a3 --- /dev/null +++ b/client/internal/updater/supported_other.go @@ -0,0 +1,7 @@ +//go:build !windows && !darwin + +package updater + +func isAutoUpdateSupported() bool { + return false +} diff --git a/client/internal/updater/supported_windows.go b/client/internal/updater/supported_windows.go new file mode 100644 index 000000000..0c28878c7 --- /dev/null +++ b/client/internal/updater/supported_windows.go @@ -0,0 +1,5 @@ +package updater + +func isAutoUpdateSupported() bool { + return true +} diff --git a/client/internal/updatemanager/update.go b/client/internal/updater/update.go similarity index 90% rename from client/internal/updatemanager/update.go rename to client/internal/updater/update.go index 875b50b49..3056c77e1 100644 --- a/client/internal/updatemanager/update.go +++ b/client/internal/updater/update.go @@ -1,4 +1,4 @@ -package updatemanager +package updater import v "github.com/hashicorp/go-version" diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index aafef41d3..3e2da7f4e 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -160,7 +160,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { c.onHostDnsFn = func([]string) {} cfg.WgIface = interfaceName - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 3879beba3..fa0b2f93b 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v6.33.3 +// protoc v6.33.1 // source: daemon.proto package proto @@ -95,6 +95,7 @@ const ( ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1 ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2 ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3 + ExposeProtocol_EXPOSE_TLS ExposeProtocol = 4 ) // Enum value maps for ExposeProtocol. @@ -104,12 +105,14 @@ var ( 1: "EXPOSE_HTTPS", 2: "EXPOSE_TCP", 3: "EXPOSE_UDP", + 4: "EXPOSE_TLS", } ExposeProtocol_value = map[string]int32{ "EXPOSE_HTTP": 0, "EXPOSE_HTTPS": 1, "EXPOSE_TCP": 2, "EXPOSE_UDP": 3, + "EXPOSE_TLS": 4, } ) @@ -945,7 +948,6 @@ type UpRequest struct { state protoimpl.MessageState `protogen:"open.v1"` ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"` - AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -994,13 +996,6 @@ func (x *UpRequest) GetUsername() string { return "" } -func (x *UpRequest) GetAutoUpdate() bool { - if x != nil && x.AutoUpdate != nil { - return *x.AutoUpdate - } - return false -} - type UpResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -5032,6 +5027,94 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool { return false } +type TriggerUpdateRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TriggerUpdateRequest) Reset() { + *x = TriggerUpdateRequest{} + mi := &file_daemon_proto_msgTypes[73] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TriggerUpdateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TriggerUpdateRequest) ProtoMessage() {} + +func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[73] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TriggerUpdateRequest.ProtoReflect.Descriptor instead. +func (*TriggerUpdateRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{73} +} + +type TriggerUpdateResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TriggerUpdateResponse) Reset() { + *x = TriggerUpdateResponse{} + mi := &file_daemon_proto_msgTypes[74] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TriggerUpdateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TriggerUpdateResponse) ProtoMessage() {} + +func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[74] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TriggerUpdateResponse.ProtoReflect.Descriptor instead. +func (*TriggerUpdateResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{74} +} + +func (x *TriggerUpdateResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *TriggerUpdateResponse) GetErrorMsg() string { + if x != nil { + return x.ErrorMsg + } + return "" +} + // GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer type GetPeerSSHHostKeyRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -5043,7 +5126,7 @@ type GetPeerSSHHostKeyRequest struct { func (x *GetPeerSSHHostKeyRequest) Reset() { *x = GetPeerSSHHostKeyRequest{} - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[75] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5055,7 +5138,7 @@ func (x *GetPeerSSHHostKeyRequest) String() string { func (*GetPeerSSHHostKeyRequest) ProtoMessage() {} func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[75] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5068,7 +5151,7 @@ func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead. func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{73} + return file_daemon_proto_rawDescGZIP(), []int{75} } func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string { @@ -5095,7 +5178,7 @@ type GetPeerSSHHostKeyResponse struct { func (x *GetPeerSSHHostKeyResponse) Reset() { *x = GetPeerSSHHostKeyResponse{} - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[76] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5107,7 +5190,7 @@ func (x *GetPeerSSHHostKeyResponse) String() string { func (*GetPeerSSHHostKeyResponse) ProtoMessage() {} func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[76] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5120,7 +5203,7 @@ func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead. func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{74} + return file_daemon_proto_rawDescGZIP(), []int{76} } func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte { @@ -5162,7 +5245,7 @@ type RequestJWTAuthRequest struct { func (x *RequestJWTAuthRequest) Reset() { *x = RequestJWTAuthRequest{} - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[77] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5174,7 +5257,7 @@ func (x *RequestJWTAuthRequest) String() string { func (*RequestJWTAuthRequest) ProtoMessage() {} func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[77] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5187,7 +5270,7 @@ func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead. func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{75} + return file_daemon_proto_rawDescGZIP(), []int{77} } func (x *RequestJWTAuthRequest) GetHint() string { @@ -5220,7 +5303,7 @@ type RequestJWTAuthResponse struct { func (x *RequestJWTAuthResponse) Reset() { *x = RequestJWTAuthResponse{} - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[78] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5232,7 +5315,7 @@ func (x *RequestJWTAuthResponse) String() string { func (*RequestJWTAuthResponse) ProtoMessage() {} func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[78] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5245,7 +5328,7 @@ func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead. func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{76} + return file_daemon_proto_rawDescGZIP(), []int{78} } func (x *RequestJWTAuthResponse) GetVerificationURI() string { @@ -5310,7 +5393,7 @@ type WaitJWTTokenRequest struct { func (x *WaitJWTTokenRequest) Reset() { *x = WaitJWTTokenRequest{} - mi := &file_daemon_proto_msgTypes[77] + mi := &file_daemon_proto_msgTypes[79] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5322,7 +5405,7 @@ func (x *WaitJWTTokenRequest) String() string { func (*WaitJWTTokenRequest) ProtoMessage() {} func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[77] + mi := &file_daemon_proto_msgTypes[79] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5335,7 +5418,7 @@ func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead. func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{77} + return file_daemon_proto_rawDescGZIP(), []int{79} } func (x *WaitJWTTokenRequest) GetDeviceCode() string { @@ -5367,7 +5450,7 @@ type WaitJWTTokenResponse struct { func (x *WaitJWTTokenResponse) Reset() { *x = WaitJWTTokenResponse{} - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[80] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5379,7 +5462,7 @@ func (x *WaitJWTTokenResponse) String() string { func (*WaitJWTTokenResponse) ProtoMessage() {} func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[80] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5392,7 +5475,7 @@ func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead. func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{78} + return file_daemon_proto_rawDescGZIP(), []int{80} } func (x *WaitJWTTokenResponse) GetToken() string { @@ -5425,7 +5508,7 @@ type StartCPUProfileRequest struct { func (x *StartCPUProfileRequest) Reset() { *x = StartCPUProfileRequest{} - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[81] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5437,7 +5520,7 @@ func (x *StartCPUProfileRequest) String() string { func (*StartCPUProfileRequest) ProtoMessage() {} func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[81] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5450,7 +5533,7 @@ func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead. func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{79} + return file_daemon_proto_rawDescGZIP(), []int{81} } // StartCPUProfileResponse confirms CPU profiling has started @@ -5462,7 +5545,7 @@ type StartCPUProfileResponse struct { func (x *StartCPUProfileResponse) Reset() { *x = StartCPUProfileResponse{} - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[82] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5474,7 +5557,7 @@ func (x *StartCPUProfileResponse) String() string { func (*StartCPUProfileResponse) ProtoMessage() {} func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[82] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5487,7 +5570,7 @@ func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead. func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{80} + return file_daemon_proto_rawDescGZIP(), []int{82} } // StopCPUProfileRequest for stopping CPU profiling @@ -5499,7 +5582,7 @@ type StopCPUProfileRequest struct { func (x *StopCPUProfileRequest) Reset() { *x = StopCPUProfileRequest{} - mi := &file_daemon_proto_msgTypes[81] + mi := &file_daemon_proto_msgTypes[83] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5511,7 +5594,7 @@ func (x *StopCPUProfileRequest) String() string { func (*StopCPUProfileRequest) ProtoMessage() {} func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[81] + mi := &file_daemon_proto_msgTypes[83] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5524,7 +5607,7 @@ func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead. func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{81} + return file_daemon_proto_rawDescGZIP(), []int{83} } // StopCPUProfileResponse confirms CPU profiling has stopped @@ -5536,7 +5619,7 @@ type StopCPUProfileResponse struct { func (x *StopCPUProfileResponse) Reset() { *x = StopCPUProfileResponse{} - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[84] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5548,7 +5631,7 @@ func (x *StopCPUProfileResponse) String() string { func (*StopCPUProfileResponse) ProtoMessage() {} func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[84] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5561,7 +5644,7 @@ func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead. func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{82} + return file_daemon_proto_rawDescGZIP(), []int{84} } type InstallerResultRequest struct { @@ -5572,7 +5655,7 @@ type InstallerResultRequest struct { func (x *InstallerResultRequest) Reset() { *x = InstallerResultRequest{} - mi := &file_daemon_proto_msgTypes[83] + mi := &file_daemon_proto_msgTypes[85] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5584,7 +5667,7 @@ func (x *InstallerResultRequest) String() string { func (*InstallerResultRequest) ProtoMessage() {} func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[83] + mi := &file_daemon_proto_msgTypes[85] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5597,7 +5680,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead. func (*InstallerResultRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{83} + return file_daemon_proto_rawDescGZIP(), []int{85} } type InstallerResultResponse struct { @@ -5610,7 +5693,7 @@ type InstallerResultResponse struct { func (x *InstallerResultResponse) Reset() { *x = InstallerResultResponse{} - mi := &file_daemon_proto_msgTypes[84] + mi := &file_daemon_proto_msgTypes[86] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5622,7 +5705,7 @@ func (x *InstallerResultResponse) String() string { func (*InstallerResultResponse) ProtoMessage() {} func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[84] + mi := &file_daemon_proto_msgTypes[86] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5635,7 +5718,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead. func (*InstallerResultResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{84} + return file_daemon_proto_rawDescGZIP(), []int{86} } func (x *InstallerResultResponse) GetSuccess() bool { @@ -5661,13 +5744,14 @@ type ExposeServiceRequest struct { UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"` Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"` NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"` + ListenPort uint32 `protobuf:"varint,8,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *ExposeServiceRequest) Reset() { *x = ExposeServiceRequest{} - mi := &file_daemon_proto_msgTypes[85] + mi := &file_daemon_proto_msgTypes[87] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5679,7 +5763,7 @@ func (x *ExposeServiceRequest) String() string { func (*ExposeServiceRequest) ProtoMessage() {} func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[85] + mi := &file_daemon_proto_msgTypes[87] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5692,7 +5776,7 @@ func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead. func (*ExposeServiceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{85} + return file_daemon_proto_rawDescGZIP(), []int{87} } func (x *ExposeServiceRequest) GetPort() uint32 { @@ -5744,6 +5828,13 @@ func (x *ExposeServiceRequest) GetNamePrefix() string { return "" } +func (x *ExposeServiceRequest) GetListenPort() uint32 { + if x != nil { + return x.ListenPort + } + return 0 +} + type ExposeServiceEvent struct { state protoimpl.MessageState `protogen:"open.v1"` // Types that are valid to be assigned to Event: @@ -5756,7 +5847,7 @@ type ExposeServiceEvent struct { func (x *ExposeServiceEvent) Reset() { *x = ExposeServiceEvent{} - mi := &file_daemon_proto_msgTypes[86] + mi := &file_daemon_proto_msgTypes[88] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5768,7 +5859,7 @@ func (x *ExposeServiceEvent) String() string { func (*ExposeServiceEvent) ProtoMessage() {} func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[86] + mi := &file_daemon_proto_msgTypes[88] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5781,7 +5872,7 @@ func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead. func (*ExposeServiceEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{86} + return file_daemon_proto_rawDescGZIP(), []int{88} } func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event { @@ -5811,17 +5902,18 @@ type ExposeServiceEvent_Ready struct { func (*ExposeServiceEvent_Ready) isExposeServiceEvent_Event() {} type ExposeServiceReady struct { - state protoimpl.MessageState `protogen:"open.v1"` - ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` - ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` - Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` + ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` + Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + PortAutoAssigned bool `protobuf:"varint,4,opt,name=port_auto_assigned,json=portAutoAssigned,proto3" json:"port_auto_assigned,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ExposeServiceReady) Reset() { *x = ExposeServiceReady{} - mi := &file_daemon_proto_msgTypes[87] + mi := &file_daemon_proto_msgTypes[89] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5833,7 +5925,7 @@ func (x *ExposeServiceReady) String() string { func (*ExposeServiceReady) ProtoMessage() {} func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[87] + mi := &file_daemon_proto_msgTypes[89] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5846,7 +5938,7 @@ func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead. func (*ExposeServiceReady) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{87} + return file_daemon_proto_rawDescGZIP(), []int{89} } func (x *ExposeServiceReady) GetServiceName() string { @@ -5870,6 +5962,13 @@ func (x *ExposeServiceReady) GetDomain() string { return "" } +func (x *ExposeServiceReady) GetPortAutoAssigned() bool { + if x != nil { + return x.PortAutoAssigned + } + return false +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -5880,7 +5979,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[89] + mi := &file_daemon_proto_msgTypes[91] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5892,7 +5991,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[89] + mi := &file_daemon_proto_msgTypes[91] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -6016,16 +6115,12 @@ const file_daemon_proto_rawDesc = "" + "\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" + "\bhostname\x18\x02 \x01(\tR\bhostname\",\n" + "\x14WaitSSOLoginResponse\x12\x14\n" + - "\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" + + "\x05email\x18\x01 \x01(\tR\x05email\"v\n" + "\tUpRequest\x12%\n" + "\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" + - "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" + - "\n" + - "autoUpdate\x18\x03 \x01(\bH\x02R\n" + - "autoUpdate\x88\x01\x01B\x0e\n" + + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + - "\t_usernameB\r\n" + - "\v_autoUpdate\"\f\n" + + "\t_usernameJ\x04\b\x03\x10\x04\"\f\n" + "\n" + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + @@ -6380,7 +6475,11 @@ const file_daemon_proto_rawDesc = "" + "\x12GetFeaturesRequest\"x\n" + "\x13GetFeaturesResponse\x12)\n" + "\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" + - "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"<\n" + + "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"\x16\n" + + "\x14TriggerUpdateRequest\"M\n" + + "\x15TriggerUpdateResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"<\n" + "\x18GetPeerSSHHostKeyRequest\x12 \n" + "\vpeerAddress\x18\x01 \x01(\tR\vpeerAddress\"\x85\x01\n" + "\x19GetPeerSSHHostKeyResponse\x12\x1e\n" + @@ -6419,7 +6518,7 @@ const file_daemon_proto_rawDesc = "" + "\x16InstallerResultRequest\"O\n" + "\x17InstallerResultResponse\x12\x18\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + - "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\xe6\x01\n" + + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\x87\x02\n" + "\x14ExposeServiceRequest\x12\x12\n" + "\x04port\x18\x01 \x01(\rR\x04port\x122\n" + "\bprotocol\x18\x02 \x01(\x0e2\x16.daemon.ExposeProtocolR\bprotocol\x12\x10\n" + @@ -6429,15 +6528,18 @@ const file_daemon_proto_rawDesc = "" + "userGroups\x12\x16\n" + "\x06domain\x18\x06 \x01(\tR\x06domain\x12\x1f\n" + "\vname_prefix\x18\a \x01(\tR\n" + - "namePrefix\"Q\n" + + "namePrefix\x12\x1f\n" + + "\vlisten_port\x18\b \x01(\rR\n" + + "listenPort\"Q\n" + "\x12ExposeServiceEvent\x122\n" + "\x05ready\x18\x01 \x01(\v2\x1a.daemon.ExposeServiceReadyH\x00R\x05readyB\a\n" + - "\x05event\"p\n" + + "\x05event\"\x9e\x01\n" + "\x12ExposeServiceReady\x12!\n" + "\fservice_name\x18\x01 \x01(\tR\vserviceName\x12\x1f\n" + "\vservice_url\x18\x02 \x01(\tR\n" + "serviceUrl\x12\x16\n" + - "\x06domain\x18\x03 \x01(\tR\x06domain*b\n" + + "\x06domain\x18\x03 \x01(\tR\x06domain\x12,\n" + + "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -6446,14 +6548,16 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a*S\n" + + "\x05TRACE\x10\a*c\n" + "\x0eExposeProtocol\x12\x0f\n" + "\vEXPOSE_HTTP\x10\x00\x12\x10\n" + "\fEXPOSE_HTTPS\x10\x01\x12\x0e\n" + "\n" + "EXPOSE_TCP\x10\x02\x12\x0e\n" + "\n" + - "EXPOSE_UDP\x10\x032\xac\x15\n" + + "EXPOSE_UDP\x10\x03\x12\x0e\n" + + "\n" + + "EXPOSE_TLS\x10\x042\xfc\x15\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -6485,7 +6589,8 @@ const file_daemon_proto_rawDesc = "" + "\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" + "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" + "\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" + - "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12N\n" + + "\rTriggerUpdate\x12\x1c.daemon.TriggerUpdateRequest\x1a\x1d.daemon.TriggerUpdateResponse\"\x00\x12Z\n" + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" + @@ -6508,7 +6613,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 91) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 93) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (ExposeProtocol)(0), // 1: daemon.ExposeProtocol @@ -6588,34 +6693,36 @@ var file_daemon_proto_goTypes = []any{ (*LogoutResponse)(nil), // 75: daemon.LogoutResponse (*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest (*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse - (*GetPeerSSHHostKeyRequest)(nil), // 78: daemon.GetPeerSSHHostKeyRequest - (*GetPeerSSHHostKeyResponse)(nil), // 79: daemon.GetPeerSSHHostKeyResponse - (*RequestJWTAuthRequest)(nil), // 80: daemon.RequestJWTAuthRequest - (*RequestJWTAuthResponse)(nil), // 81: daemon.RequestJWTAuthResponse - (*WaitJWTTokenRequest)(nil), // 82: daemon.WaitJWTTokenRequest - (*WaitJWTTokenResponse)(nil), // 83: daemon.WaitJWTTokenResponse - (*StartCPUProfileRequest)(nil), // 84: daemon.StartCPUProfileRequest - (*StartCPUProfileResponse)(nil), // 85: daemon.StartCPUProfileResponse - (*StopCPUProfileRequest)(nil), // 86: daemon.StopCPUProfileRequest - (*StopCPUProfileResponse)(nil), // 87: daemon.StopCPUProfileResponse - (*InstallerResultRequest)(nil), // 88: daemon.InstallerResultRequest - (*InstallerResultResponse)(nil), // 89: daemon.InstallerResultResponse - (*ExposeServiceRequest)(nil), // 90: daemon.ExposeServiceRequest - (*ExposeServiceEvent)(nil), // 91: daemon.ExposeServiceEvent - (*ExposeServiceReady)(nil), // 92: daemon.ExposeServiceReady - nil, // 93: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 94: daemon.PortInfo.Range - nil, // 95: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 96: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 97: google.protobuf.Timestamp + (*TriggerUpdateRequest)(nil), // 78: daemon.TriggerUpdateRequest + (*TriggerUpdateResponse)(nil), // 79: daemon.TriggerUpdateResponse + (*GetPeerSSHHostKeyRequest)(nil), // 80: daemon.GetPeerSSHHostKeyRequest + (*GetPeerSSHHostKeyResponse)(nil), // 81: daemon.GetPeerSSHHostKeyResponse + (*RequestJWTAuthRequest)(nil), // 82: daemon.RequestJWTAuthRequest + (*RequestJWTAuthResponse)(nil), // 83: daemon.RequestJWTAuthResponse + (*WaitJWTTokenRequest)(nil), // 84: daemon.WaitJWTTokenRequest + (*WaitJWTTokenResponse)(nil), // 85: daemon.WaitJWTTokenResponse + (*StartCPUProfileRequest)(nil), // 86: daemon.StartCPUProfileRequest + (*StartCPUProfileResponse)(nil), // 87: daemon.StartCPUProfileResponse + (*StopCPUProfileRequest)(nil), // 88: daemon.StopCPUProfileRequest + (*StopCPUProfileResponse)(nil), // 89: daemon.StopCPUProfileResponse + (*InstallerResultRequest)(nil), // 90: daemon.InstallerResultRequest + (*InstallerResultResponse)(nil), // 91: daemon.InstallerResultResponse + (*ExposeServiceRequest)(nil), // 92: daemon.ExposeServiceRequest + (*ExposeServiceEvent)(nil), // 93: daemon.ExposeServiceEvent + (*ExposeServiceReady)(nil), // 94: daemon.ExposeServiceReady + nil, // 95: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 96: daemon.PortInfo.Range + nil, // 97: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 98: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 99: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ 2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType - 96, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 98, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 97, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 97, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 96, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 99, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 99, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 98, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration 26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo 23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState @@ -6626,8 +6733,8 @@ var file_daemon_proto_depIdxs = []int32{ 58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent 27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState 34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 93, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 94, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 95, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 96, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule @@ -6638,13 +6745,13 @@ var file_daemon_proto_depIdxs = []int32{ 55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 97, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 95, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 99, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 97, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 96, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 98, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol - 92, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady + 94, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady 33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest @@ -6674,52 +6781,54 @@ var file_daemon_proto_depIdxs = []int32{ 72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest 74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest 76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 78, // 64: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest - 80, // 65: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest - 82, // 66: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 84, // 67: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest - 86, // 68: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest - 6, // 69: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest - 88, // 70: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest - 90, // 71: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest - 9, // 72: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 11, // 73: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 13, // 74: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 15, // 75: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 17, // 76: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 19, // 77: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 30, // 78: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 32, // 79: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 32, // 80: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 37, // 81: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 39, // 82: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 41, // 83: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 43, // 84: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 46, // 85: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 48, // 86: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 50, // 87: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 52, // 88: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 56, // 89: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 58, // 90: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 60, // 91: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 62, // 92: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 64, // 93: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 66, // 94: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 68, // 95: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 70, // 96: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 73, // 97: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 75, // 98: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 77, // 99: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 79, // 100: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 81, // 101: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 83, // 102: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 85, // 103: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse - 87, // 104: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse - 7, // 105: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse - 89, // 106: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse - 91, // 107: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent - 72, // [72:108] is the sub-list for method output_type - 36, // [36:72] is the sub-list for method input_type + 78, // 64: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest + 80, // 65: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 82, // 66: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 84, // 67: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 86, // 68: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest + 88, // 69: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest + 6, // 70: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest + 90, // 71: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 92, // 72: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest + 9, // 73: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 11, // 74: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 13, // 75: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 15, // 76: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 17, // 77: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 19, // 78: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 30, // 79: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 32, // 80: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 32, // 81: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 37, // 82: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 39, // 83: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 41, // 84: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 43, // 85: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 46, // 86: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 48, // 87: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 50, // 88: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 52, // 89: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 56, // 90: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 58, // 91: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 60, // 92: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 62, // 93: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 64, // 94: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 66, // 95: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 68, // 96: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 70, // 97: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 73, // 98: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 75, // 99: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 77, // 100: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 79, // 101: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse + 81, // 102: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 83, // 103: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 85, // 104: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 87, // 105: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse + 89, // 106: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse + 7, // 107: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 91, // 108: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 93, // 109: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent + 73, // [73:110] is the sub-list for method output_type + 36, // [36:73] is the sub-list for method input_type 36, // [36:36] is the sub-list for extension type_name 36, // [36:36] is the sub-list for extension extendee 0, // [0:36] is the sub-list for field type_name @@ -6742,8 +6851,8 @@ func file_daemon_proto_init() { file_daemon_proto_msgTypes[56].OneofWrappers = []any{} file_daemon_proto_msgTypes[58].OneofWrappers = []any{} file_daemon_proto_msgTypes[69].OneofWrappers = []any{} - file_daemon_proto_msgTypes[75].OneofWrappers = []any{} - file_daemon_proto_msgTypes[86].OneofWrappers = []any{ + file_daemon_proto_msgTypes[77].OneofWrappers = []any{} + file_daemon_proto_msgTypes[88].OneofWrappers = []any{ (*ExposeServiceEvent_Ready)(nil), } type x struct{} @@ -6752,7 +6861,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 5, - NumMessages: 91, + NumMessages: 93, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 4dc41d401..89302c8c3 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -85,6 +85,10 @@ service DaemonService { rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {} + // TriggerUpdate initiates installation of the pending enforced version. + // Called when the user clicks the install button in the UI (Mode 2 / enforced update). + rpc TriggerUpdate(TriggerUpdateRequest) returns (TriggerUpdateResponse) {} + // GetPeerSSHHostKey retrieves SSH host key for a specific peer rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {} @@ -226,7 +230,7 @@ message WaitSSOLoginResponse { message UpRequest { optional string profileName = 1; optional string username = 2; - optional bool autoUpdate = 3; + reserved 3; } message UpResponse {} @@ -725,6 +729,13 @@ message GetFeaturesResponse{ bool disable_update_settings = 2; } +message TriggerUpdateRequest {} + +message TriggerUpdateResponse { + bool success = 1; + string errorMsg = 2; +} + // GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer message GetPeerSSHHostKeyRequest { // peer IP address or FQDN to get SSH host key for @@ -810,6 +821,7 @@ enum ExposeProtocol { EXPOSE_HTTPS = 1; EXPOSE_TCP = 2; EXPOSE_UDP = 3; + EXPOSE_TLS = 4; } message ExposeServiceRequest { @@ -820,6 +832,7 @@ message ExposeServiceRequest { repeated string user_groups = 5; string domain = 6; string name_prefix = 7; + uint32 listen_port = 8; } message ExposeServiceEvent { @@ -832,4 +845,5 @@ message ExposeServiceReady { string service_name = 1; string service_url = 2; string domain = 3; + bool port_auto_assigned = 4; } diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 4154dce59..e5bd89597 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -64,6 +64,9 @@ type DaemonServiceClient interface { // Logout disconnects from the network and deletes the peer from the management server Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error) + // TriggerUpdate initiates installation of the pending enforced version. + // Called when the user clicks the install button in the UI (Mode 2 / enforced update). + TriggerUpdate(ctx context.Context, in *TriggerUpdateRequest, opts ...grpc.CallOption) (*TriggerUpdateResponse, error) // GetPeerSSHHostKey retrieves SSH host key for a specific peer GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) // RequestJWTAuth initiates JWT authentication flow for SSH @@ -363,6 +366,15 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe return out, nil } +func (c *daemonServiceClient) TriggerUpdate(ctx context.Context, in *TriggerUpdateRequest, opts ...grpc.CallOption) (*TriggerUpdateResponse, error) { + out := new(TriggerUpdateResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/TriggerUpdate", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) { out := new(GetPeerSSHHostKeyResponse) err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...) @@ -508,6 +520,9 @@ type DaemonServiceServer interface { // Logout disconnects from the network and deletes the peer from the management server Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) + // TriggerUpdate initiates installation of the pending enforced version. + // Called when the user clicks the install button in the UI (Mode 2 / enforced update). + TriggerUpdate(context.Context, *TriggerUpdateRequest) (*TriggerUpdateResponse, error) // GetPeerSSHHostKey retrieves SSH host key for a specific peer GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) // RequestJWTAuth initiates JWT authentication flow for SSH @@ -613,6 +628,9 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented") } +func (UnimplementedDaemonServiceServer) TriggerUpdate(context.Context, *TriggerUpdateRequest) (*TriggerUpdateResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method TriggerUpdate not implemented") +} func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") } @@ -1157,6 +1175,24 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_TriggerUpdate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TriggerUpdateRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).TriggerUpdate(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/TriggerUpdate", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).TriggerUpdate(ctx, req.(*TriggerUpdateRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetPeerSSHHostKeyRequest) if err := dec(in); err != nil { @@ -1419,6 +1455,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetFeatures", Handler: _DaemonService_GetFeatures_Handler, }, + { + MethodName: "TriggerUpdate", + Handler: _DaemonService_TriggerUpdate_Handler, + }, { MethodName: "GetPeerSSHHostKey", Handler: _DaemonService_GetPeerSSHHostKey_Handler, diff --git a/client/server/event.go b/client/server/event.go index b5c12a3a6..d93151c96 100644 --- a/client/server/event.go +++ b/client/server/event.go @@ -14,6 +14,7 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo }() log.Debug("client subscribed to events") + s.startUpdateManagerForGUI() for { select { diff --git a/client/server/server.go b/client/server/server.go index 69d79d9cd..7c1e70692 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -30,6 +30,8 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/updater" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" ) @@ -89,6 +91,8 @@ type Server struct { sleepHandler *sleephandler.SleepHandler + updateManager *updater.Manager + jwtCache *jwtCache } @@ -135,6 +139,12 @@ func (s *Server) Start() error { log.Warnf(errRestoreResidualState, err) } + if s.updateManager == nil { + stateMgr := statemanager.New(s.profileManager.GetStatePath()) + s.updateManager = updater.NewManager(s.statusRecorder, stateMgr) + s.updateManager.CheckUpdateSuccess(s.rootCtx) + } + // if current state contains any error, return it // in all other cases we can continue execution only if status is idle and up command was // not in the progress or already successfully established connection. @@ -192,14 +202,14 @@ func (s *Server) Start() error { s.clientRunning = true s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) return nil } // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) { +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { defer func() { s.mutex.Lock() s.clientRunning = false @@ -207,7 +217,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() if s.config.DisableAutoConnect { - if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil { + if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { log.Debugf("run client connection exited with error: %v", err) } log.Tracef("client connection exited") @@ -236,8 +246,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() runOperation := func() error { - err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan) - doInitialAutoUpdate = false + err := s.connect(ctx, profileConfig, statusRecorder, runningChan) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) return err @@ -717,11 +726,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - var doAutoUpdate bool - if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate { - doAutoUpdate = true - } - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) s.mutex.Unlock() return s.waitForUp(callerCtx) @@ -1373,9 +1378,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon if err := srv.Send(&proto.ExposeServiceEvent{ Event: &proto.ExposeServiceEvent_Ready{ Ready: &proto.ExposeServiceReady{ - ServiceName: result.ServiceName, - ServiceUrl: result.ServiceURL, - Domain: result.Domain, + ServiceName: result.ServiceName, + ServiceUrl: result.ServiceURL, + Domain: result.Domain, + PortAutoAssigned: result.PortAutoAssigned, }, }, }); err != nil { @@ -1623,9 +1629,10 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) return features, nil } -func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error { +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { log.Tracef("running client connection") - client := internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate) + client := internal.NewConnectClient(ctx, config, statusRecorder) + client.SetUpdateManager(s.updateManager) client.SetSyncResponsePersistence(s.persistSyncResponse) s.mutex.Lock() @@ -1656,6 +1663,14 @@ func (s *Server) checkUpdateSettingsDisabled() bool { return false } +func (s *Server) startUpdateManagerForGUI() { + if s.updateManager == nil { + return + } + s.updateManager.Start(s.rootCtx) + s.updateManager.NotifyUI() +} + func (s *Server) onSessionExpire() { if runtime.GOOS != "windows" { isUIActive := internal.CheckUIApp() diff --git a/client/server/server_connect_test.go b/client/server/server_connect_test.go index 8d31c2ae6..faea7da39 100644 --- a/client/server/server_connect_test.go +++ b/client/server/server_connect_test.go @@ -22,7 +22,7 @@ func newTestServer() *Server { } func newDummyConnectClient(ctx context.Context) *internal.ConnectClient { - return internal.NewConnectClient(ctx, nil, nil, false) + return internal.NewConnectClient(ctx, nil, nil) } // TestConnectSetsClientWithMutex validates that connect() sets s.connectClient diff --git a/client/server/server_test.go b/client/server/server_test.go index 82079c531..6de23d501 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -113,7 +113,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } diff --git a/client/server/triggerupdate.go b/client/server/triggerupdate.go new file mode 100644 index 000000000..ffcb527e7 --- /dev/null +++ b/client/server/triggerupdate.go @@ -0,0 +1,24 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/proto" +) + +// TriggerUpdate initiates installation of the pending enforced version. +// It is called when the user clicks the install button in the UI (Mode 2 / enforced update). +func (s *Server) TriggerUpdate(ctx context.Context, _ *proto.TriggerUpdateRequest) (*proto.TriggerUpdateResponse, error) { + if s.updateManager == nil { + return &proto.TriggerUpdateResponse{Success: false, ErrorMsg: "update manager not available"}, nil + } + + if err := s.updateManager.Install(ctx); err != nil { + log.Warnf("TriggerUpdate failed: %v", err) + return &proto.TriggerUpdateResponse{Success: false, ErrorMsg: err.Error()}, nil + } + + return &proto.TriggerUpdateResponse{Success: true}, nil +} diff --git a/client/server/updateresult.go b/client/server/updateresult.go index 8e00d5062..8d1ef0e5f 100644 --- a/client/server/updateresult.go +++ b/client/server/updateresult.go @@ -5,7 +5,7 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/client/proto" ) diff --git a/client/ssh/server/getent_cgo_unix.go b/client/ssh/server/getent_cgo_unix.go new file mode 100644 index 000000000..4afbfc627 --- /dev/null +++ b/client/ssh/server/getent_cgo_unix.go @@ -0,0 +1,24 @@ +//go:build cgo && !osusergo && !windows + +package server + +import "os/user" + +// lookupWithGetent with CGO delegates directly to os/user.Lookup. +// When CGO is enabled, os/user uses libc (getpwnam_r) which goes through +// the NSS stack natively. If it fails, the user truly doesn't exist and +// getent would also fail. +func lookupWithGetent(username string) (*user.User, error) { + return user.Lookup(username) +} + +// currentUserWithGetent with CGO delegates directly to os/user.Current. +func currentUserWithGetent() (*user.User, error) { + return user.Current() +} + +// groupIdsWithFallback with CGO delegates directly to user.GroupIds. +// libc's getgrouplist handles NSS groups natively. +func groupIdsWithFallback(u *user.User) ([]string, error) { + return u.GroupIds() +} diff --git a/client/ssh/server/getent_nocgo_unix.go b/client/ssh/server/getent_nocgo_unix.go new file mode 100644 index 000000000..314daae4c --- /dev/null +++ b/client/ssh/server/getent_nocgo_unix.go @@ -0,0 +1,74 @@ +//go:build (!cgo || osusergo) && !windows + +package server + +import ( + "os" + "os/user" + "strconv" + + log "github.com/sirupsen/logrus" +) + +// lookupWithGetent looks up a user by name, falling back to getent if os/user fails. +// Without CGO, os/user only reads /etc/passwd and misses NSS-provided users. +// getent goes through the host's NSS stack. +func lookupWithGetent(username string) (*user.User, error) { + u, err := user.Lookup(username) + if err == nil { + return u, nil + } + + stdErr := err + log.Debugf("os/user.Lookup(%q) failed, trying getent: %v", username, err) + + u, _, getentErr := runGetent(username) + if getentErr != nil { + log.Debugf("getent fallback for %q also failed: %v", username, getentErr) + return nil, stdErr + } + + return u, nil +} + +// currentUserWithGetent gets the current user, falling back to getent if os/user fails. +func currentUserWithGetent() (*user.User, error) { + u, err := user.Current() + if err == nil { + return u, nil + } + + stdErr := err + uid := strconv.Itoa(os.Getuid()) + log.Debugf("os/user.Current() failed, trying getent with UID %s: %v", uid, err) + + u, _, getentErr := runGetent(uid) + if getentErr != nil { + return nil, stdErr + } + + return u, nil +} + +// groupIdsWithFallback gets group IDs for a user via the id command first, +// falling back to user.GroupIds(). +// NOTE: unlike lookupWithGetent/currentUserWithGetent which try stdlib first, +// this intentionally tries `id -G` first because without CGO, user.GroupIds() +// only reads /etc/group and silently returns incomplete results for NSS users +// (no error, just missing groups). The id command goes through NSS and returns +// the full set. +func groupIdsWithFallback(u *user.User) ([]string, error) { + ids, err := runIdGroups(u.Username) + if err == nil { + return ids, nil + } + + log.Debugf("id -G %q failed, falling back to user.GroupIds(): %v", u.Username, err) + + ids, stdErr := u.GroupIds() + if stdErr != nil { + return nil, stdErr + } + + return ids, nil +} diff --git a/client/ssh/server/getent_test.go b/client/ssh/server/getent_test.go new file mode 100644 index 000000000..5eac2fdbe --- /dev/null +++ b/client/ssh/server/getent_test.go @@ -0,0 +1,172 @@ +package server + +import ( + "os/user" + "runtime" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLookupWithGetent_CurrentUser(t *testing.T) { + // The current user should always be resolvable on any platform + current, err := user.Current() + require.NoError(t, err) + + u, err := lookupWithGetent(current.Username) + require.NoError(t, err) + assert.Equal(t, current.Username, u.Username) + assert.Equal(t, current.Uid, u.Uid) + assert.Equal(t, current.Gid, u.Gid) +} + +func TestLookupWithGetent_NonexistentUser(t *testing.T) { + _, err := lookupWithGetent("nonexistent_user_xyzzy_12345") + require.Error(t, err, "should fail for nonexistent user") +} + +func TestCurrentUserWithGetent(t *testing.T) { + stdUser, err := user.Current() + require.NoError(t, err) + + u, err := currentUserWithGetent() + require.NoError(t, err) + assert.Equal(t, stdUser.Uid, u.Uid) + assert.Equal(t, stdUser.Username, u.Username) +} + +func TestGroupIdsWithFallback_CurrentUser(t *testing.T) { + current, err := user.Current() + require.NoError(t, err) + + groups, err := groupIdsWithFallback(current) + require.NoError(t, err) + require.NotEmpty(t, groups, "current user should have at least one group") + + if runtime.GOOS != "windows" { + for _, gid := range groups { + _, err := strconv.ParseUint(gid, 10, 32) + assert.NoError(t, err, "group ID %q should be a valid uint32", gid) + } + } +} + +func TestGetShellFromGetent_CurrentUser(t *testing.T) { + if runtime.GOOS == "windows" { + // Windows stub always returns empty, which is correct + shell := getShellFromGetent("1000") + assert.Empty(t, shell, "Windows stub should return empty") + return + } + + current, err := user.Current() + require.NoError(t, err) + + // getent may not be available on all systems (e.g., macOS without Homebrew getent) + shell := getShellFromGetent(current.Uid) + if shell == "" { + t.Log("getShellFromGetent returned empty, getent may not be available") + return + } + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) +} + +func TestLookupWithGetent_RootUser(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("no root user on Windows") + } + + u, err := lookupWithGetent("root") + if err != nil { + t.Skip("root user not available on this system") + } + assert.Equal(t, "0", u.Uid, "root should have UID 0") +} + +// TestIntegration_FullLookupChain exercises the complete user lookup chain +// against the real system, testing that all wrappers (lookupWithGetent, +// currentUserWithGetent, groupIdsWithFallback, getShellFromGetent) produce +// consistent and correct results when composed together. +func TestIntegration_FullLookupChain(t *testing.T) { + // Step 1: currentUserWithGetent must resolve the running user. + current, err := currentUserWithGetent() + require.NoError(t, err, "currentUserWithGetent must resolve the running user") + require.NotEmpty(t, current.Uid) + require.NotEmpty(t, current.Username) + + // Step 2: lookupWithGetent by the same username must return matching identity. + byName, err := lookupWithGetent(current.Username) + require.NoError(t, err) + assert.Equal(t, current.Uid, byName.Uid, "lookup by name should return same UID") + assert.Equal(t, current.Gid, byName.Gid, "lookup by name should return same GID") + assert.Equal(t, current.HomeDir, byName.HomeDir, "lookup by name should return same home") + + // Step 3: groupIdsWithFallback must return at least the primary GID. + groups, err := groupIdsWithFallback(current) + require.NoError(t, err) + require.NotEmpty(t, groups, "user must have at least one group") + + foundPrimary := false + for _, gid := range groups { + if runtime.GOOS != "windows" { + _, err := strconv.ParseUint(gid, 10, 32) + require.NoError(t, err, "group ID %q must be a valid uint32", gid) + } + if gid == current.Gid { + foundPrimary = true + } + } + assert.True(t, foundPrimary, "primary GID %s should appear in supplementary groups", current.Gid) + + // Step 4: getShellFromGetent should either return a valid shell path or empty + // (empty is OK when getent is not available, e.g. macOS without Homebrew getent). + if runtime.GOOS != "windows" { + shell := getShellFromGetent(current.Uid) + if shell != "" { + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) + } + } +} + +// TestIntegration_LookupAndGroupsConsistency verifies that a user resolved via +// lookupWithGetent can have their groups resolved via groupIdsWithFallback, +// testing the handoff between the two functions as used by the SSH server. +func TestIntegration_LookupAndGroupsConsistency(t *testing.T) { + current, err := user.Current() + require.NoError(t, err) + + // Simulate the SSH server flow: lookup user, then get their groups. + resolved, err := lookupWithGetent(current.Username) + require.NoError(t, err) + + groups, err := groupIdsWithFallback(resolved) + require.NoError(t, err) + require.NotEmpty(t, groups, "resolved user must have groups") + + // On Unix, all returned GIDs must be valid numeric values. + // On Windows, group IDs are SIDs (e.g., "S-1-5-32-544"). + if runtime.GOOS != "windows" { + for _, gid := range groups { + _, err := strconv.ParseUint(gid, 10, 32) + assert.NoError(t, err, "group ID %q should be numeric", gid) + } + } +} + +// TestIntegration_ShellLookupChain tests the full shell resolution chain +// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix. +func TestIntegration_ShellLookupChain(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Unix shell lookup not applicable on Windows") + } + + current, err := user.Current() + require.NoError(t, err) + + // getUserShell is the top-level function used by the SSH server. + shell := getUserShell(current.Uid) + require.NotEmpty(t, shell, "getUserShell must always return a shell") + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) +} diff --git a/client/ssh/server/getent_unix.go b/client/ssh/server/getent_unix.go new file mode 100644 index 000000000..18edb2fdf --- /dev/null +++ b/client/ssh/server/getent_unix.go @@ -0,0 +1,122 @@ +//go:build !windows + +package server + +import ( + "context" + "fmt" + "os/exec" + "os/user" + "runtime" + "strings" + "time" +) + +const getentTimeout = 5 * time.Second + +// getShellFromGetent gets a user's login shell via getent by UID. +// This is needed even with CGO because getShellFromPasswd reads /etc/passwd +// directly and won't find NSS-provided users there. +func getShellFromGetent(userID string) string { + _, shell, err := runGetent(userID) + if err != nil { + return "" + } + return shell +} + +// runGetent executes `getent passwd ` and returns the user and login shell. +func runGetent(query string) (*user.User, string, error) { + if !validateGetentInput(query) { + return nil, "", fmt.Errorf("invalid getent input: %q", query) + } + + ctx, cancel := context.WithTimeout(context.Background(), getentTimeout) + defer cancel() + + out, err := exec.CommandContext(ctx, "getent", "passwd", query).Output() + if err != nil { + return nil, "", fmt.Errorf("getent passwd %s: %w", query, err) + } + + return parseGetentPasswd(string(out)) +} + +// parseGetentPasswd parses getent passwd output: "name:x:uid:gid:gecos:home:shell" +func parseGetentPasswd(output string) (*user.User, string, error) { + fields := strings.SplitN(strings.TrimSpace(output), ":", 8) + if len(fields) < 6 { + return nil, "", fmt.Errorf("unexpected getent output (need 6+ fields): %q", output) + } + + if fields[0] == "" || fields[2] == "" || fields[3] == "" { + return nil, "", fmt.Errorf("missing required fields in getent output: %q", output) + } + + var shell string + if len(fields) >= 7 { + shell = fields[6] + } + + return &user.User{ + Username: fields[0], + Uid: fields[2], + Gid: fields[3], + Name: fields[4], + HomeDir: fields[5], + }, shell, nil +} + +// validateGetentInput checks that the input is safe to pass to getent or id. +// Allows POSIX usernames, numeric UIDs, and common NSS extensions +// (@ for Kerberos, $ for Samba, + for NIS compat). +func validateGetentInput(input string) bool { + maxLen := 32 + if runtime.GOOS == "linux" { + maxLen = 256 + } + + if len(input) == 0 || len(input) > maxLen { + return false + } + + for _, r := range input { + if isAllowedGetentChar(r) { + continue + } + return false + } + return true +} + +func isAllowedGetentChar(r rune) bool { + if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' { + return true + } + switch r { + case '.', '_', '-', '@', '+', '$': + return true + } + return false +} + +// runIdGroups runs `id -G ` and returns the space-separated group IDs. +func runIdGroups(username string) ([]string, error) { + if !validateGetentInput(username) { + return nil, fmt.Errorf("invalid username for id command: %q", username) + } + + ctx, cancel := context.WithTimeout(context.Background(), getentTimeout) + defer cancel() + + out, err := exec.CommandContext(ctx, "id", "-G", username).Output() + if err != nil { + return nil, fmt.Errorf("id -G %s: %w", username, err) + } + + trimmed := strings.TrimSpace(string(out)) + if trimmed == "" { + return nil, fmt.Errorf("id -G %s: empty output", username) + } + return strings.Fields(trimmed), nil +} diff --git a/client/ssh/server/getent_unix_test.go b/client/ssh/server/getent_unix_test.go new file mode 100644 index 000000000..e44563b79 --- /dev/null +++ b/client/ssh/server/getent_unix_test.go @@ -0,0 +1,410 @@ +//go:build !windows + +package server + +import ( + "os/exec" + "os/user" + "runtime" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseGetentPasswd(t *testing.T) { + tests := []struct { + name string + input string + wantUser *user.User + wantShell string + wantErr bool + errContains string + }{ + { + name: "standard entry", + input: "alice:x:1001:1001:Alice Smith:/home/alice:/bin/bash\n", + wantUser: &user.User{ + Username: "alice", + Uid: "1001", + Gid: "1001", + Name: "Alice Smith", + HomeDir: "/home/alice", + }, + wantShell: "/bin/bash", + }, + { + name: "root entry", + input: "root:x:0:0:root:/root:/bin/bash", + wantUser: &user.User{ + Username: "root", + Uid: "0", + Gid: "0", + Name: "root", + HomeDir: "/root", + }, + wantShell: "/bin/bash", + }, + { + name: "empty gecos field", + input: "svc:x:999:999::/var/lib/svc:/usr/sbin/nologin", + wantUser: &user.User{ + Username: "svc", + Uid: "999", + Gid: "999", + Name: "", + HomeDir: "/var/lib/svc", + }, + wantShell: "/usr/sbin/nologin", + }, + { + name: "gecos with commas", + input: "john:x:1002:1002:John Doe,Room 101,555-1234,555-4321:/home/john:/bin/zsh", + wantUser: &user.User{ + Username: "john", + Uid: "1002", + Gid: "1002", + Name: "John Doe,Room 101,555-1234,555-4321", + HomeDir: "/home/john", + }, + wantShell: "/bin/zsh", + }, + { + name: "remote user with large UID", + input: "remoteuser:*:50001:50001:Remote User:/home/remoteuser:/bin/bash\n", + wantUser: &user.User{ + Username: "remoteuser", + Uid: "50001", + Gid: "50001", + Name: "Remote User", + HomeDir: "/home/remoteuser", + }, + wantShell: "/bin/bash", + }, + { + name: "no shell field (only 6 fields)", + input: "minimal:x:1000:1000::/home/minimal", + wantUser: &user.User{ + Username: "minimal", + Uid: "1000", + Gid: "1000", + Name: "", + HomeDir: "/home/minimal", + }, + wantShell: "", + }, + { + name: "too few fields", + input: "bad:x:1000", + wantErr: true, + errContains: "need 6+ fields", + }, + { + name: "empty username", + input: ":x:1000:1000::/home/test:/bin/bash", + wantErr: true, + errContains: "missing required fields", + }, + { + name: "empty UID", + input: "test:x::1000::/home/test:/bin/bash", + wantErr: true, + errContains: "missing required fields", + }, + { + name: "empty GID", + input: "test:x:1000:::/home/test:/bin/bash", + wantErr: true, + errContains: "missing required fields", + }, + { + name: "empty input", + input: "", + wantErr: true, + errContains: "need 6+ fields", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, shell, err := parseGetentPasswd(tt.input) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantUser.Username, u.Username, "username") + assert.Equal(t, tt.wantUser.Uid, u.Uid, "UID") + assert.Equal(t, tt.wantUser.Gid, u.Gid, "GID") + assert.Equal(t, tt.wantUser.Name, u.Name, "name/gecos") + assert.Equal(t, tt.wantUser.HomeDir, u.HomeDir, "home directory") + assert.Equal(t, tt.wantShell, shell, "shell") + }) + } +} + +func TestValidateGetentInput(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"normal username", "alice", true}, + {"numeric UID", "1001", true}, + {"dots and underscores", "alice.bob_test", true}, + {"hyphen", "alice-bob", true}, + {"kerberos principal", "user@REALM", true}, + {"samba machine account", "MACHINE$", true}, + {"NIS compat", "+user", true}, + {"empty", "", false}, + {"null byte", "alice\x00bob", false}, + {"newline", "alice\nbob", false}, + {"tab", "alice\tbob", false}, + {"control char", "alice\x01bob", false}, + {"DEL char", "alice\x7fbob", false}, + {"space rejected", "alice bob", false}, + {"semicolon rejected", "alice;bob", false}, + {"backtick rejected", "alice`bob", false}, + {"pipe rejected", "alice|bob", false}, + {"33 chars exceeds non-linux max", makeLongString(33), runtime.GOOS == "linux"}, + {"256 chars at linux max", makeLongString(256), runtime.GOOS == "linux"}, + {"257 chars exceeds all limits", makeLongString(257), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, validateGetentInput(tt.input)) + }) + } +} + +func makeLongString(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = 'a' + } + return string(b) +} + +func TestRunGetent_RootUser(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + u, shell, err := runGetent("root") + require.NoError(t, err) + assert.Equal(t, "root", u.Username) + assert.Equal(t, "0", u.Uid) + assert.Equal(t, "0", u.Gid) + assert.NotEmpty(t, shell, "root should have a shell") +} + +func TestRunGetent_ByUID(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + u, _, err := runGetent("0") + require.NoError(t, err) + assert.Equal(t, "root", u.Username) + assert.Equal(t, "0", u.Uid) +} + +func TestRunGetent_NonexistentUser(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + _, _, err := runGetent("nonexistent_user_xyzzy_12345") + assert.Error(t, err) +} + +func TestRunGetent_InvalidInput(t *testing.T) { + _, _, err := runGetent("") + assert.Error(t, err) + + _, _, err = runGetent("user\x00name") + assert.Error(t, err) +} + +func TestRunGetent_NotAvailable(t *testing.T) { + if _, err := exec.LookPath("getent"); err == nil { + t.Skip("getent is available, can't test missing case") + } + + _, _, err := runGetent("root") + assert.Error(t, err, "should fail when getent is not installed") +} + +func TestRunIdGroups_CurrentUser(t *testing.T) { + if _, err := exec.LookPath("id"); err != nil { + t.Skip("id not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + groups, err := runIdGroups(current.Username) + require.NoError(t, err) + require.NotEmpty(t, groups, "current user should have at least one group") + + for _, gid := range groups { + _, err := strconv.ParseUint(gid, 10, 32) + assert.NoError(t, err, "group ID %q should be a valid uint32", gid) + } +} + +func TestRunIdGroups_NonexistentUser(t *testing.T) { + if _, err := exec.LookPath("id"); err != nil { + t.Skip("id not available on this system") + } + + _, err := runIdGroups("nonexistent_user_xyzzy_12345") + assert.Error(t, err) +} + +func TestRunIdGroups_InvalidInput(t *testing.T) { + _, err := runIdGroups("") + assert.Error(t, err) + + _, err = runIdGroups("user\x00name") + assert.Error(t, err) +} + +func TestGetentResultsMatchStdlib(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + getentUser, _, err := runGetent(current.Username) + require.NoError(t, err) + + assert.Equal(t, current.Username, getentUser.Username, "username should match") + assert.Equal(t, current.Uid, getentUser.Uid, "UID should match") + assert.Equal(t, current.Gid, getentUser.Gid, "GID should match") + assert.Equal(t, current.HomeDir, getentUser.HomeDir, "home directory should match") +} + +func TestGetentResultsMatchStdlib_ByUID(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + getentUser, _, err := runGetent(current.Uid) + require.NoError(t, err) + + assert.Equal(t, current.Username, getentUser.Username, "username should match when looked up by UID") + assert.Equal(t, current.Uid, getentUser.Uid, "UID should match") +} + +func TestIdGroupsMatchStdlib(t *testing.T) { + if _, err := exec.LookPath("id"); err != nil { + t.Skip("id not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + stdGroups, err := current.GroupIds() + if err != nil { + t.Skip("os/user.GroupIds() not working, likely CGO_ENABLED=0") + } + + idGroups, err := runIdGroups(current.Username) + require.NoError(t, err) + + // Deduplicate both lists: id -G can return duplicates (e.g., root in Docker) + // and ElementsMatch treats duplicates as distinct. + assert.ElementsMatch(t, uniqueStrings(stdGroups), uniqueStrings(idGroups), "id -G should return same groups as os/user") +} + +func uniqueStrings(ss []string) []string { + seen := make(map[string]struct{}, len(ss)) + out := make([]string, 0, len(ss)) + for _, s := range ss { + if _, ok := seen[s]; ok { + continue + } + seen[s] = struct{}{} + out = append(out, s) + } + return out +} + +// TestGetShellFromPasswd_CurrentUser verifies that getShellFromPasswd correctly +// reads the current user's shell from /etc/passwd by comparing it against what +// getent reports (which goes through NSS). +func TestGetShellFromPasswd_CurrentUser(t *testing.T) { + current, err := user.Current() + require.NoError(t, err) + + shell := getShellFromPasswd(current.Uid) + if shell == "" { + t.Skip("current user not found in /etc/passwd (may be an NSS-only user)") + } + + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) + + if _, err := exec.LookPath("getent"); err == nil { + _, getentShell, getentErr := runGetent(current.Uid) + if getentErr == nil && getentShell != "" { + assert.Equal(t, getentShell, shell, "shell from /etc/passwd should match getent") + } + } +} + +// TestGetShellFromPasswd_RootUser verifies that getShellFromPasswd can read +// root's shell from /etc/passwd. Root is guaranteed to be in /etc/passwd on +// any standard Unix system. +func TestGetShellFromPasswd_RootUser(t *testing.T) { + shell := getShellFromPasswd("0") + require.NotEmpty(t, shell, "root (UID 0) must be in /etc/passwd") + assert.True(t, shell[0] == '/', "root shell should be an absolute path, got %q", shell) +} + +// TestGetShellFromPasswd_NonexistentUID verifies that getShellFromPasswd +// returns empty for a UID that doesn't exist in /etc/passwd. +func TestGetShellFromPasswd_NonexistentUID(t *testing.T) { + shell := getShellFromPasswd("4294967294") + assert.Empty(t, shell, "nonexistent UID should return empty shell") +} + +// TestGetShellFromPasswd_MatchesGetentForKnownUsers reads /etc/passwd directly +// and cross-validates every entry against getent to ensure parseGetentPasswd +// and getShellFromPasswd agree on shell values. +func TestGetShellFromPasswd_MatchesGetentForKnownUsers(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available") + } + + // Pick a few well-known system UIDs that are virtually always in /etc/passwd. + uids := []string{"0"} // root + + current, err := user.Current() + require.NoError(t, err) + uids = append(uids, current.Uid) + + for _, uid := range uids { + passwdShell := getShellFromPasswd(uid) + if passwdShell == "" { + continue + } + + _, getentShell, err := runGetent(uid) + if err != nil { + continue + } + + assert.Equal(t, getentShell, passwdShell, "shell mismatch for UID %s", uid) + } +} diff --git a/client/ssh/server/getent_windows.go b/client/ssh/server/getent_windows.go new file mode 100644 index 000000000..3e76b3e8e --- /dev/null +++ b/client/ssh/server/getent_windows.go @@ -0,0 +1,26 @@ +//go:build windows + +package server + +import "os/user" + +// lookupWithGetent on Windows just delegates to os/user.Lookup. +// Windows does not use NSS/getent; its user lookup works without CGO. +func lookupWithGetent(username string) (*user.User, error) { + return user.Lookup(username) +} + +// currentUserWithGetent on Windows just delegates to os/user.Current. +func currentUserWithGetent() (*user.User, error) { + return user.Current() +} + +// getShellFromGetent is a no-op on Windows; shell resolution uses PowerShell detection. +func getShellFromGetent(_ string) string { + return "" +} + +// groupIdsWithFallback on Windows just delegates to u.GroupIds(). +func groupIdsWithFallback(u *user.User) ([]string, error) { + return u.GroupIds() +} diff --git a/client/ssh/server/shell.go b/client/ssh/server/shell.go index fea9d2910..1e8ff5e31 100644 --- a/client/ssh/server/shell.go +++ b/client/ssh/server/shell.go @@ -49,10 +49,14 @@ func getWindowsUserShell() string { return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe` } -// getUnixUserShell returns the shell for Unix-like systems +// getUnixUserShell returns the shell for Unix-like systems. +// Tries /etc/passwd first (fast, no subprocess), falls back to getent for NSS users. func getUnixUserShell(userID string) string { - shell := getShellFromPasswd(userID) - if shell != "" { + if shell := getShellFromPasswd(userID); shell != "" { + return shell + } + + if shell := getShellFromGetent(userID); shell != "" { return shell } diff --git a/client/ssh/server/user_utils.go b/client/ssh/server/user_utils.go index 799882cbb..bc2aa2d7d 100644 --- a/client/ssh/server/user_utils.go +++ b/client/ssh/server/user_utils.go @@ -23,8 +23,8 @@ func isPlatformUnix() bool { // Dependency injection variables for testing - allows mocking dynamic runtime checks var ( - getCurrentUser = user.Current - lookupUser = user.Lookup + getCurrentUser = currentUserWithGetent + lookupUser = lookupWithGetent getCurrentOS = func() string { return runtime.GOOS } getIsProcessPrivileged = isCurrentProcessPrivileged diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go index d80b77042..220e2240f 100644 --- a/client/ssh/server/userswitching_unix.go +++ b/client/ssh/server/userswitching_unix.go @@ -146,32 +146,30 @@ func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []u } gid := uint32(gid64) - groups, err := s.getSupplementaryGroups(localUser.Username) - if err != nil { - log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err) + groups, err := s.getSupplementaryGroups(localUser) + if err != nil || len(groups) == 0 { + if err != nil { + log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err) + } groups = []uint32{gid} } return uid, gid, groups, nil } -// getSupplementaryGroups retrieves supplementary group IDs for a user -func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) { - u, err := user.Lookup(username) +// getSupplementaryGroups retrieves supplementary group IDs for a user. +// Uses id/getent fallback for NSS users in CGO_ENABLED=0 builds. +func (s *Server) getSupplementaryGroups(u *user.User) ([]uint32, error) { + groupIDStrings, err := groupIdsWithFallback(u) if err != nil { - return nil, fmt.Errorf("lookup user %s: %w", username, err) - } - - groupIDStrings, err := u.GroupIds() - if err != nil { - return nil, fmt.Errorf("get group IDs for user %s: %w", username, err) + return nil, fmt.Errorf("get group IDs for user %s: %w", u.Username, err) } groups := make([]uint32, len(groupIDStrings)) for i, gidStr := range groupIDStrings { gid64, err := strconv.ParseUint(gidStr, 10, 32) if err != nil { - return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err) + return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, u.Username, err) } groups[i] = uint32(gid64) } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 7af00cd20..0574e53d0 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -34,7 +34,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - protobuf "google.golang.org/protobuf/proto" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" @@ -308,10 +307,11 @@ type serviceClient struct { sshJWTCacheTTL int connected bool - update *version.Update daemonVersion string updateIndicationLock sync.Mutex isUpdateIconActive bool + isEnforcedUpdate bool + lastNotifiedVersion string settingsEnabled bool profilesEnabled bool showNetworks bool @@ -367,7 +367,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { showAdvancedSettings: args.showSettings, showNetworks: args.showNetworks, - update: version.NewUpdateAndStart("nb/client-ui"), } s.eventHandler = newEventHandler(s) @@ -828,7 +827,7 @@ func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.Log return nil } -func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error { +func (s *serviceClient) menuUpClick(ctx context.Context) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { @@ -850,9 +849,7 @@ func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) e return nil } - if _, err := s.conn.Up(s.ctx, &proto.UpRequest{ - AutoUpdate: protobuf.Bool(wannaAutoUpdate), - }); err != nil { + if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { return fmt.Errorf("start connection: %w", err) } @@ -933,13 +930,13 @@ func (s *serviceClient) updateStatus() error { systrayIconState = false } - // the updater struct notify by the upgrades available only, but if meanwhile the daemon has successfully - // updated must reset the mUpdate visibility state + // if the daemon version changed (e.g. after a successful update), reset the update indication if s.daemonVersion != status.DaemonVersion { - s.mUpdate.Hide() + if s.daemonVersion != "" { + s.mUpdate.Hide() + s.isUpdateIconActive = false + } s.daemonVersion = status.DaemonVersion - - s.isUpdateIconActive = s.update.SetDaemonVersion(status.DaemonVersion) if !s.isUpdateIconActive { if systrayIconState { systray.SetTemplateIcon(iconConnectedMacOS, s.icConnected) @@ -1091,7 +1088,6 @@ func (s *serviceClient) onTrayReady() { // update exit node menu in case service is already connected go s.updateExitNodes() - s.update.SetOnUpdateListener(s.onUpdateAvailable) go func() { s.getSrvConfig() time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon @@ -1135,6 +1131,13 @@ func (s *serviceClient) onTrayReady() { } } }) + s.eventManager.AddHandler(func(event *proto.SystemEvent) { + if newVersion, ok := event.Metadata["new_version_available"]; ok { + _, enforced := event.Metadata["enforced"] + log.Infof("received new_version_available event: version=%s enforced=%v", newVersion, enforced) + s.onUpdateAvailable(newVersion, enforced) + } + }) go s.eventManager.Start(s.ctx) go s.eventHandler.listen(s.ctx) @@ -1507,10 +1510,18 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { return &config } -func (s *serviceClient) onUpdateAvailable() { +func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) { s.updateIndicationLock.Lock() defer s.updateIndicationLock.Unlock() + s.isEnforcedUpdate = enforced + if enforced { + s.mUpdate.SetTitle("Install version " + newVersion) + } else { + s.lastNotifiedVersion = "" + s.mUpdate.SetTitle("Download latest version") + } + s.mUpdate.Show() s.isUpdateIconActive = true @@ -1519,6 +1530,11 @@ func (s *serviceClient) onUpdateAvailable() { } else { systray.SetTemplateIcon(iconUpdateDisconnectedMacOS, s.icUpdateDisconnected) } + + if enforced && s.lastNotifiedVersion != newVersion { + s.lastNotifiedVersion = newVersion + s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install")) + } } // onSessionExpire sends a notification to the user when the session expires. diff --git a/client/ui/event/event.go b/client/ui/event/event.go index 4d949416d..b8ed09a5c 100644 --- a/client/ui/event/event.go +++ b/client/ui/event/event.go @@ -107,12 +107,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) { handlers := slices.Clone(e.handlers) e.mu.Unlock() - // critical events are always shown - if !enabled && event.Severity != proto.SystemEvent_CRITICAL { - return - } - - if event.UserMessage != "" { + if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) { title := e.getEventTitle(event) body := event.UserMessage id := event.Metadata["id"] diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index 6adf8778c..60a580dae 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -82,7 +82,7 @@ func (h *eventHandler) handleConnectClick() { go func() { defer connectCancel() - if err := h.client.menuUpClick(connectCtx, true); err != nil { + if err := h.client.menuUpClick(connectCtx); err != nil { st, ok := status.FromError(err) if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { log.Debugf("connect operation cancelled by user") @@ -211,9 +211,42 @@ func (h *eventHandler) handleGitHubClick() { } func (h *eventHandler) handleUpdateClick() { - if err := openURL(version.DownloadUrl()); err != nil { - log.Errorf("failed to open download URL: %v", err) + h.client.updateIndicationLock.Lock() + enforced := h.client.isEnforcedUpdate + h.client.updateIndicationLock.Unlock() + + if !enforced { + if err := openURL(version.DownloadUrl()); err != nil { + log.Errorf("failed to open download URL: %v", err) + } + return } + + // prevent blocking against a busy server + h.client.mUpdate.Disable() + go func() { + defer h.client.mUpdate.Enable() + conn, err := h.client.getSrvClient(defaultFailTimeout) + if err != nil { + log.Errorf("failed to get service client for update: %v", err) + _ = openURL(version.DownloadUrl()) + return + } + + resp, err := conn.TriggerUpdate(h.client.ctx, &proto.TriggerUpdateRequest{}) + if err != nil { + log.Errorf("TriggerUpdate failed: %v", err) + _ = openURL(version.DownloadUrl()) + return + } + if !resp.Success { + log.Errorf("TriggerUpdate failed: %s", resp.ErrorMsg) + _ = openURL(version.DownloadUrl()) + return + } + + log.Infof("update triggered via daemon") + }() } func (h *eventHandler) handleNetworksClick() { diff --git a/client/ui/profile.go b/client/ui/profile.go index a38d8918a..74189c9a0 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -397,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func(context.Context, bool) error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -411,7 +411,7 @@ type newProfileMenuArgs struct { profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func(context.Context, bool) error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -579,7 +579,7 @@ func (p *profileMenu) refresh() { connectCtx, connectCancel := context.WithCancel(p.ctx) p.serviceClient.connectCancel = connectCancel - if err := p.upClickCallback(connectCtx, false); err != nil { + if err := p.upClickCallback(connectCtx); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } diff --git a/client/ui/quickactions.go b/client/ui/quickactions.go index 76440d684..bf47ac434 100644 --- a/client/ui/quickactions.go +++ b/client/ui/quickactions.go @@ -267,7 +267,7 @@ func (s *serviceClient) showQuickActionsUI() { connCmd := connectCommand{ connectClient: func() error { - return s.menuUpClick(s.ctx, false) + return s.menuUpClick(s.ctx) }, } diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 7fd87ee8e..70088d66a 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -182,44 +182,6 @@ read_enable_proxy() { return 0 } -read_proxy_domain() { - local suggested_proxy="proxy.${BASE_DOMAIN}" - - echo "" > /dev/stderr - echo "NOTE: The proxy domain must be different from the management domain ($NETBIRD_DOMAIN)" > /dev/stderr - echo "to avoid TLS certificate conflicts." > /dev/stderr - echo "" > /dev/stderr - echo "You also need to add a wildcard DNS record for the proxy domain," > /dev/stderr - echo "e.g. *.${suggested_proxy} pointing to the same server domain as $NETBIRD_DOMAIN with a CNAME record." > /dev/stderr - echo "" > /dev/stderr - echo -n "Enter the domain for the NetBird Proxy (e.g. ${suggested_proxy}): " > /dev/stderr - read -r READ_PROXY_DOMAIN < /dev/tty - - if [[ -z "$READ_PROXY_DOMAIN" ]]; then - echo "The proxy domain cannot be empty." > /dev/stderr - read_proxy_domain - return - fi - - if [[ "$READ_PROXY_DOMAIN" == "$NETBIRD_DOMAIN" ]]; then - echo "" > /dev/stderr - echo "WARNING: The proxy domain cannot be the same as the management domain ($NETBIRD_DOMAIN)." > /dev/stderr - read_proxy_domain - return - fi - - echo ${READ_PROXY_DOMAIN} | grep ${NETBIRD_DOMAIN} > /dev/null - if [[ $? -eq 0 ]]; then - echo "" > /dev/stderr - echo "WARNING: The proxy domain cannot be a subdomain of the management domain ($NETBIRD_DOMAIN)." > /dev/stderr - read_proxy_domain - return - fi - - echo "$READ_PROXY_DOMAIN" - return 0 -} - read_traefik_acme_email() { echo "" > /dev/stderr echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr @@ -334,7 +296,6 @@ initialize_default_values() { # NetBird Proxy configuration ENABLE_PROXY="false" - PROXY_DOMAIN="" PROXY_TOKEN="" return 0 } @@ -364,9 +325,6 @@ configure_reverse_proxy() { if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email) ENABLE_PROXY=$(read_enable_proxy) - if [[ "$ENABLE_PROXY" == "true" ]]; then - PROXY_DOMAIN=$(read_proxy_domain) - fi fi # Handle external Traefik-specific prompts (option 1) @@ -813,7 +771,7 @@ NB_PROXY_MANAGEMENT_ADDRESS=http://netbird-server:80 # Allow insecure gRPC connection to management (required for internal Docker network) NB_PROXY_ALLOW_INSECURE=true # Public URL where this proxy is reachable (used for cluster registration) -NB_PROXY_DOMAIN=$PROXY_DOMAIN +NB_PROXY_DOMAIN=$NETBIRD_DOMAIN NB_PROXY_ADDRESS=:8443 NB_PROXY_TOKEN=$PROXY_TOKEN NB_PROXY_CERTIFICATE_DIRECTORY=/certs @@ -1203,8 +1161,7 @@ print_builtin_traefik_instructions() { echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge." echo " Point your proxy domain to this server's domain address like in the examples below:" echo "" - echo " $PROXY_DOMAIN CNAME $NETBIRD_DOMAIN" - echo " *.$PROXY_DOMAIN CNAME $NETBIRD_DOMAIN" + echo " *.$NETBIRD_DOMAIN CNAME $NETBIRD_DOMAIN" echo "" fi return 0 diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 121c55ac5..4b414df6f 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -87,9 +87,14 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App newNetworkMapBuilder = false } - compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted)) - if err != nil { - log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err) + compactedNetworkMap := true + compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted) + parsedCompactedNmap, err := strconv.ParseBool(compactedEnv) + if err != nil && len(compactedEnv) > 0 { + log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err) + } + if err == nil && !parsedCompactedNmap { + log.WithContext(ctx).Info("disabling compacted mode") compactedNetworkMap = false } diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index 2f796a5d1..7cb0f3908 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -210,7 +210,7 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee }, } - _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, false) + _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true) if err != nil { return fmt.Errorf("failed to create proxy peer: %w", err) } diff --git a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go index 0bcc59b68..619a34684 100644 --- a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go +++ b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go @@ -10,6 +10,15 @@ import ( "github.com/netbirdio/netbird/shared/management/proto" ) +// AccessLogProtocol identifies the transport protocol of an access log entry. +type AccessLogProtocol string + +const ( + AccessLogProtocolHTTP AccessLogProtocol = "http" + AccessLogProtocolTCP AccessLogProtocol = "tcp" + AccessLogProtocolUDP AccessLogProtocol = "udp" +) + type AccessLogEntry struct { ID string `gorm:"primaryKey"` AccountID string `gorm:"index"` @@ -22,10 +31,11 @@ type AccessLogEntry struct { Duration time.Duration `gorm:"index"` StatusCode int `gorm:"index"` Reason string - UserId string `gorm:"index"` - AuthMethodUsed string `gorm:"index"` - BytesUpload int64 `gorm:"index"` - BytesDownload int64 `gorm:"index"` + UserId string `gorm:"index"` + AuthMethodUsed string `gorm:"index"` + BytesUpload int64 `gorm:"index"` + BytesDownload int64 `gorm:"index"` + Protocol AccessLogProtocol `gorm:"index"` } // FromProto creates an AccessLogEntry from a proto.AccessLog @@ -43,17 +53,22 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) { a.AccountID = serviceLog.GetAccountId() a.BytesUpload = serviceLog.GetBytesUpload() a.BytesDownload = serviceLog.GetBytesDownload() + a.Protocol = AccessLogProtocol(serviceLog.GetProtocol()) if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" { - if ip, err := netip.ParseAddr(sourceIP); err == nil { - a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice()) + if addr, err := netip.ParseAddr(sourceIP); err == nil { + addr = addr.Unmap() + a.GeoLocation.ConnectionIP = net.IP(addr.AsSlice()) } } - if !serviceLog.GetAuthSuccess() { - a.Reason = "Authentication failed" - } else if serviceLog.GetResponseCode() >= 400 { - a.Reason = "Request failed" + // Only set reason for HTTP entries. L4 entries have no auth or status code. + if a.Protocol == "" || a.Protocol == AccessLogProtocolHTTP { + if !serviceLog.GetAuthSuccess() { + a.Reason = "Authentication failed" + } else if serviceLog.GetResponseCode() >= 400 { + a.Reason = "Request failed" + } } } @@ -90,6 +105,12 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { cityName = &a.GeoLocation.CityName } + var protocol *string + if a.Protocol != "" { + p := string(a.Protocol) + protocol = &p + } + return &api.ProxyAccessLog{ Id: a.ID, ServiceId: a.ServiceID, @@ -107,5 +128,6 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { CityName: cityName, BytesUpload: a.BytesUpload, BytesDownload: a.BytesDownload, + Protocol: protocol, } } diff --git a/management/internals/modules/reverseproxy/domain/domain.go b/management/internals/modules/reverseproxy/domain/domain.go index 83fd669af..861d026a7 100644 --- a/management/internals/modules/reverseproxy/domain/domain.go +++ b/management/internals/modules/reverseproxy/domain/domain.go @@ -14,6 +14,9 @@ type Domain struct { TargetCluster string // The proxy cluster this domain should be validated against Type Type `gorm:"-"` Validated bool + // SupportsCustomPorts is populated at query time for free domains from the + // proxy cluster capabilities. Not persisted. + SupportsCustomPorts *bool `gorm:"-"` } // EventMeta returns activity event metadata for a domain diff --git a/management/internals/modules/reverseproxy/domain/manager/api.go b/management/internals/modules/reverseproxy/domain/manager/api.go index 2fbcdd5b8..d26a6a418 100644 --- a/management/internals/modules/reverseproxy/domain/manager/api.go +++ b/management/internals/modules/reverseproxy/domain/manager/api.go @@ -42,10 +42,11 @@ func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType { func domainToApi(d *domain.Domain) api.ReverseProxyDomain { resp := api.ReverseProxyDomain{ - Domain: d.Domain, - Id: d.ID, - Type: domainTypeToApi(d.Type), - Validated: d.Validated, + Domain: d.Domain, + Id: d.ID, + Type: domainTypeToApi(d.Type), + Validated: d.Validated, + SupportsCustomPorts: d.SupportsCustomPorts, } if d.TargetCluster != "" { resp.TargetCluster = &d.TargetCluster diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 8bbc98726..813027ea2 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -33,11 +33,16 @@ type proxyManager interface { GetActiveClusterAddresses(ctx context.Context) ([]string, error) } +type clusterCapabilities interface { + ClusterSupportsCustomPorts(clusterAddr string) *bool +} + type Manager struct { - store store - validator domain.Validator - proxyManager proxyManager - permissionsManager permissions.Manager + store store + validator domain.Validator + proxyManager proxyManager + clusterCapabilities clusterCapabilities + permissionsManager permissions.Manager accountManager account.Manager } @@ -51,6 +56,11 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio } } +// SetClusterCapabilities sets the cluster capabilities provider for domain queries. +func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) { + m.clusterCapabilities = caps +} + func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { @@ -80,24 +90,32 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d }).Debug("getting domains with proxy allow list") for _, cluster := range allowList { - ret = append(ret, &domain.Domain{ + d := &domain.Domain{ Domain: cluster, AccountID: accountID, Type: domain.TypeFree, Validated: true, - }) + } + if m.clusterCapabilities != nil { + d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster) + } + ret = append(ret, d) } // Add custom domains. for _, d := range domains { - ret = append(ret, &domain.Domain{ + cd := &domain.Domain{ ID: d.ID, Domain: d.Domain, AccountID: accountID, TargetCluster: d.TargetCluster, Type: domain.TypeCustom, Validated: d.Validated, - }) + } + if m.clusterCapabilities != nil && d.TargetCluster != "" { + cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster) + } + ret = append(ret, cd) } return ret, nil @@ -298,7 +316,7 @@ func extractClusterFromCustomDomains(domain string, customDomains []*domain.Doma // It matches the domain suffix against available clusters and returns the matching cluster. func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) { for _, cluster := range availableClusters { - if strings.HasSuffix(domain, "."+cluster) { + if domain == cluster || strings.HasSuffix(domain, "."+cluster) { return cluster, true } } diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 15f2f9f54..67a8e74fa 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -33,4 +33,5 @@ type Controller interface { RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error GetProxiesForCluster(clusterAddr string) []string + ClusterSupportsCustomPorts(clusterAddr string) *bool } diff --git a/management/internals/modules/reverseproxy/proxy/manager/controller.go b/management/internals/modules/reverseproxy/proxy/manager/controller.go index e5b3e9886..acb49c45b 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/controller.go +++ b/management/internals/modules/reverseproxy/proxy/manager/controller.go @@ -72,6 +72,11 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster return nil } +// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports. +func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool { + return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr) +} + // GetProxiesForCluster returns all proxy IDs registered for a specific cluster. func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string { proxySet, ok := c.clusterProxies.Load(clusterAddr) diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index d9645ba88..b07a21122 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -144,6 +144,20 @@ func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig)) } +// ClusterSupportsCustomPorts mocks base method. +func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts. +func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr) +} + // GetProxiesForCluster mocks base method. func (m *MockController) GetProxiesForCluster(clusterAddr string) []string { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/interface.go b/management/internals/modules/reverseproxy/service/interface.go index b420f22a8..39fd7e3ae 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -22,7 +22,7 @@ type Manager interface { GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) - RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error - StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error + RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error + StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error StartExposeReaper(ctx context.Context) } diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index 727b2c7de..bdc1f3e65 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -211,17 +211,17 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter } // RenewServiceFromPeer mocks base method. -func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { +func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain) + ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, serviceID) ret0, _ := ret[0].(error) return ret0 } // RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer. -func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, serviceID) } // SetCertificateIssuedAt mocks base method. @@ -265,17 +265,17 @@ func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Ca } // StopServiceFromPeer mocks base method. -func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { +func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain) + ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, serviceID) ret0, _ := ret[0].(error) return ret0 } // StopServiceFromPeer indicates an expected call of StopServiceFromPeer. -func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, serviceID) } // UpdateService mocks base method. diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index f28b633b8..c53219d2e 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -11,19 +11,22 @@ import ( domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) type handler struct { - manager rpservice.Manager + manager rpservice.Manager + permissionsManager permissions.Manager } // RegisterEndpoints registers all service HTTP endpoints. -func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { +func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) { h := &handler{ - manager: manager, + manager: manager, + permissionsManager: permissionsManager, } domainRouter := router.PathPrefix("/reverse-proxies").Subrouter() diff --git a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go index c831b4a22..6ff8343b9 100644 --- a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go @@ -18,8 +18,8 @@ func TestReapExpiredExposes(t *testing.T) { ctx := context.Background() resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) @@ -28,8 +28,8 @@ func TestReapExpiredExposes(t *testing.T) { // Create a non-expired service resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8081, - Protocol: "http", + Port: 8081, + Mode: "http", }) require.NoError(t, err) @@ -49,15 +49,16 @@ func TestReapAlreadyDeletedService(t *testing.T) { ctx := context.Background() resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) expireEphemeralService(t, testStore, testAccountID, resp.Domain) // Delete the service before reaping - err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) require.NoError(t, err) // Reaping should handle the already-deleted service gracefully @@ -70,8 +71,8 @@ func TestConcurrentReapAndRenew(t *testing.T) { for i := range 5 { _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080 + i, - Protocol: "http", + Port: uint16(8080 + i), + Mode: "http", }) require.NoError(t, err) } @@ -108,17 +109,19 @@ func TestRenewEphemeralService(t *testing.T) { t.Run("renew succeeds for active service", func(t *testing.T) { resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8082, - Protocol: "http", + Port: 8082, + Mode: "http", }) require.NoError(t, err) - err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svc, lookupErr := mgr.store.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, lookupErr) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID) require.NoError(t, err) }) t.Run("renew fails for nonexistent domain", func(t *testing.T) { - err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com") + err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id") require.Error(t, err) assert.Contains(t, err.Error(), "no active expose session") }) @@ -133,8 +136,8 @@ func TestCountAndExistsEphemeralServices(t *testing.T) { assert.Equal(t, int64(0), count) resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8083, - Protocol: "http", + Port: 8083, + Mode: "http", }) require.NoError(t, err) @@ -157,15 +160,15 @@ func TestMaxExposesPerPeerEnforced(t *testing.T) { for i := range maxExposesPerPeer { _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8090 + i, - Protocol: "http", + Port: uint16(8090 + i), + Mode: "http", }) require.NoError(t, err, "expose %d should succeed", i) } _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 9999, - Protocol: "http", + Port: 9999, + Mode: "http", }) require.Error(t, err) assert.Contains(t, err.Error(), "maximum number of active expose sessions") @@ -176,8 +179,8 @@ func TestReapSkipsRenewedService(t *testing.T) { ctx := context.Background() resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8086, - Protocol: "http", + Port: 8086, + Mode: "http", }) require.NoError(t, err) @@ -185,7 +188,9 @@ func TestReapSkipsRenewedService(t *testing.T) { expireEphemeralService(t, testStore, testAccountID, resp.Domain) // Renew it before the reaper runs - err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svc, err := testStore.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, err) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID) require.NoError(t, err) // Reaper should skip it because the re-check sees a fresh timestamp @@ -195,6 +200,14 @@ func TestReapSkipsRenewedService(t *testing.T) { require.NoError(t, err, "renewed service should survive reaping") } +// resolveServiceIDByDomain looks up a service ID by domain in tests. +func resolveServiceIDByDomain(t *testing.T, s store.Store, domain string) string { + t.Helper() + svc, err := s.GetServiceByDomain(context.Background(), domain) + require.NoError(t, err) + return svc.ID +} + // expireEphemeralService backdates meta_last_renewed_at to force expiration. func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) { t.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go new file mode 100644 index 000000000..c7a61ddcf --- /dev/null +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -0,0 +1,582 @@ +package manager + +import ( + "context" + "net" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/mock_server" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +const testCluster = "test-cluster" + +func boolPtr(v bool) *bool { return &v } + +// setupL4Test creates a manager with a mock proxy controller for L4 port tests. +func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Store, *proxy.MockController) { + t.Helper() + + ctrl := gomock.NewController(t) + + ctx := context.Background() + testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + err = testStore.SaveAccount(ctx, &types.Account{ + Id: testAccountID, + CreatedBy: testUserID, + Settings: &types.Settings{ + PeerExposeEnabled: true, + PeerExposeGroups: []string{testGroupID}, + }, + Users: map[string]*types.User{ + testUserID: { + Id: testUserID, + AccountID: testAccountID, + Role: types.UserRoleAdmin, + }, + }, + Peers: map[string]*nbpeer.Peer{ + testPeerID: { + ID: testPeerID, + AccountID: testAccountID, + Key: "test-key", + DNSLabel: "test-peer", + Name: "test-peer", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, + }, + }, + Groups: map[string]*types.Group{ + testGroupID: { + ID: testGroupID, + AccountID: testAccountID, + Name: "Expose Group", + }, + }, + }) + require.NoError(t, err) + + err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID) + require.NoError(t, err) + + mockCtrl := proxy.NewMockController(ctrl) + mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes() + mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes() + + accountMgr := &mock_server.MockAccountManager{ + StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { + return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID) + }, + } + + mgr := &Manager{ + store: testStore, + accountManager: accountMgr, + permissionsManager: permissions.NewManager(testStore), + proxyController: mockCtrl, + clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}}, + } + mgr.exposeReaper = &exposeReaper{manager: mgr} + + return mgr, testStore, mockCtrl +} + +// seedService creates a service directly in the store for test setup. +func seedService(t *testing.T, s store.Store, name, protocol, domain, cluster string, port uint16) *rpservice.Service { + t.Helper() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: name, + Mode: protocol, + Domain: domain, + ProxyCluster: cluster, + ListenPort: port, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: protocol, Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + err := s.CreateService(context.Background(), svc) + require.NoError(t, err) + return svc +} + +func TestPortConflict_TCPSamePortCluster(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tcp", "tcp", testCluster, testCluster, 5432) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "conflicting-tcp", + Mode: "tcp", + Domain: "conflicting-tcp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 5432, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TCP+TCP on same port/cluster should be rejected") + assert.Contains(t, err.Error(), "already in use") +} + +func TestPortConflict_UDPSamePortCluster(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-udp", "udp", testCluster, testCluster, 5432) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "conflicting-udp", + Mode: "udp", + Domain: "conflicting-udp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 5432, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "udp", Port: 9090, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "UDP+UDP on same port/cluster should be rejected") + assert.Contains(t, err.Error(), "already in use") +} + +func TestPortConflict_TLSSamePortDifferentDomain(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app1.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "new-tls", + Mode: "tls", + Domain: "app2.example.com", + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS+TLS on same port with different domains should be allowed (SNI routing)") +} + +func TestPortConflict_TLSSamePortSameDomain(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "duplicate-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TLS+TLS on same domain should be rejected") + assert.Contains(t, err.Error(), "domain already taken") +} + +func TestPortConflict_TLSAndTCPSamePort(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "new-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS+TCP on same port should be allowed (multiplexed)") +} + +func TestAutoAssign_TCPNoListenPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "auto-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax, + "auto-assigned port %d should be in range [%d, %d]", svc.ListenPort, autoAssignPortMin, autoAssignPortMax) + assert.True(t, svc.PortAutoAssigned, "PortAutoAssigned should be set") +} + +func TestAutoAssign_TCPCustomPortRejectedWhenNotSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TCP with custom port should be rejected when cluster doesn't support it") + assert.Contains(t, err.Error(), "custom ports") +} + +func TestAutoAssign_TLSCustomPortAlwaysAllowed(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 9999, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS with custom port should always be allowed regardless of cluster capability") + assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden") + assert.False(t, svc.PortAutoAssigned, "PortAutoAssigned should not be set for TLS") +} + +func TestAutoAssign_EphemeralOverridesPortWhenNotSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "ephemeral-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "ephemeral", + SourcePeer: testPeerID, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc) + require.NoError(t, err) + assert.NotEqual(t, uint16(5555), svc.ListenPort, "requested port should be overridden") + assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax, + "auto-assigned port %d should be in range", svc.ListenPort) + assert.True(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_EphemeralTLSKeepsCustomPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "ephemeral-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 9999, + Enabled: true, + Source: "ephemeral", + SourcePeer: testPeerID, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc) + require.NoError(t, err) + assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden") + assert.False(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_AvoidsExistingPorts(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existingPort := uint16(20000) + seedService(t, testStore, "existing", "tcp", testCluster, testCluster, existingPort) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "auto-tcp", + Mode: "tcp", + Domain: "auto-tcp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.NotEqual(t, existingPort, svc.ListenPort, "auto-assigned port should not collide with existing") + assert.True(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_TCPCustomPortAllowedWhenSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.Equal(t, uint16(5555), svc.ListenPort, "custom port should be preserved when supported") + assert.False(t, svc.PortAutoAssigned) +} + +func TestUpdate_PreservesExistingListenPort(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345) + + updated := &rpservice.Service{ + ID: existing.ID, + AccountID: testAccountID, + Name: "tcp-svc-renamed", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + + _, err := mgr.persistServiceUpdate(ctx, testAccountID, updated) + require.NoError(t, err) + assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when update sends 0") +} + +func TestUpdate_AllowsPortChange(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345) + + updated := &rpservice.Service{ + ID: existing.ID, + AccountID: testAccountID, + Name: "tcp-svc", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 54321, + Enabled: true, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + + _, err := mgr.persistServiceUpdate(ctx, testAccountID, updated) + require.NoError(t, err) + assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied") +} + +func TestCreateServiceFromPeer_TCP(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + }) + require.NoError(t, err) + + assert.NotEmpty(t, resp.ServiceName) + assert.Contains(t, resp.Domain, ".test.netbird.io", "TCP uses unique subdomain") + assert.True(t, resp.PortAutoAssigned, "port should be auto-assigned when cluster doesn't support custom ports") + assert.Contains(t, resp.ServiceURL, "tcp://") +} + +func TestCreateServiceFromPeer_TCP_CustomPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + ListenPort: 15432, + }) + require.NoError(t, err) + + assert.False(t, resp.PortAutoAssigned) + assert.Contains(t, resp.ServiceURL, ":15432") +} + +func TestCreateServiceFromPeer_TCP_DefaultListenPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + }) + require.NoError(t, err) + + // When no explicit listen port, defaults to target port + assert.Contains(t, resp.ServiceURL, ":5432") + assert.False(t, resp.PortAutoAssigned) +} + +func TestCreateServiceFromPeer_TLS(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 443, + Mode: "tls", + }) + require.NoError(t, err) + + assert.Contains(t, resp.Domain, ".test.netbird.io", "TLS uses subdomain") + assert.Contains(t, resp.ServiceURL, "tls://") + assert.Contains(t, resp.ServiceURL, ":443") + // TLS always keeps its port (not port-based protocol for auto-assign) + assert.False(t, resp.PortAutoAssigned) +} + +func TestCreateServiceFromPeer_TCP_StopAndRenew(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "tcp", + }) + require.NoError(t, err) + + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.NoError(t, err) + + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.NoError(t, err) + + // Renew after stop should fail + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.Error(t, err) +} + +func TestCreateServiceFromPeer_L4_RejectsAuth(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "tcp", + Pin: "123456", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "authentication is not supported") +} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 56a1fc98a..c40961fdc 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "math/rand/v2" + "os" "slices" + "strconv" "time" log "github.com/sirupsen/logrus" @@ -23,6 +25,45 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +const ( + defaultAutoAssignPortMin uint16 = 10000 + defaultAutoAssignPortMax uint16 = 49151 + + // EnvAutoAssignPortMin overrides the lower bound for auto-assigned L4 listen ports. + EnvAutoAssignPortMin = "NB_PROXY_PORT_MIN" + // EnvAutoAssignPortMax overrides the upper bound for auto-assigned L4 listen ports. + EnvAutoAssignPortMax = "NB_PROXY_PORT_MAX" +) + +var ( + autoAssignPortMin = defaultAutoAssignPortMin + autoAssignPortMax = defaultAutoAssignPortMax +) + +func init() { + autoAssignPortMin = portFromEnv(EnvAutoAssignPortMin, defaultAutoAssignPortMin) + autoAssignPortMax = portFromEnv(EnvAutoAssignPortMax, defaultAutoAssignPortMax) + if autoAssignPortMin > autoAssignPortMax { + log.Warnf("port range invalid: %s (%d) > %s (%d), using defaults", + EnvAutoAssignPortMin, autoAssignPortMin, EnvAutoAssignPortMax, autoAssignPortMax) + autoAssignPortMin = defaultAutoAssignPortMin + autoAssignPortMax = defaultAutoAssignPortMax + } +} + +func portFromEnv(key string, fallback uint16) uint16 { + val := os.Getenv(key) + if val == "" { + return fallback + } + n, err := strconv.ParseUint(val, 10, 16) + if err != nil { + log.Warnf("invalid %s value %q, using default %d: %v", key, val, fallback, err) + return fallback + } + return uint16(n) +} + const unknownHostPlaceholder = "unknown" // ClusterDeriver derives the proxy cluster from a domain. @@ -115,6 +156,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s * return fmt.Errorf("unknown target type: %s", target.TargetType) } } + return nil } @@ -197,55 +239,19 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri return nil } -func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error { +func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil { + if svc.Domain != "" { + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + return err + } + } + + if err := m.ensureL4Port(ctx, transaction, svc); err != nil { return err } - if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { - return err - } - - if err := transaction.CreateService(ctx, service); err != nil { - return fmt.Errorf("failed to create service: %w", err) - } - - return nil - }) -} - -// persistNewEphemeralService creates an ephemeral service inside a single transaction -// that also enforces the duplicate and per-peer limit checks atomically. -// The count and exists queries use FOR UPDATE locking to serialize concurrent creates -// for the same peer, preventing the per-peer limit from being bypassed. -func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { - return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - // Lock the peer row to serialize concurrent creates for the same peer. - // Without this, when no ephemeral rows exist yet, FOR UPDATE on the services - // table returns no rows and acquires no locks, allowing concurrent inserts - // to bypass the per-peer limit. - if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil { - return fmt.Errorf("lock peer row: %w", err) - } - - exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain) - if err != nil { - return fmt.Errorf("check existing expose: %w", err) - } - if exists { - return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") - } - - count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID) - if err != nil { - return fmt.Errorf("count peer exposes: %w", err) - } - if count >= int64(maxExposesPerPeer) { - return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) - } - - if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + if err := m.checkPortConflict(ctx, transaction, svc); err != nil { return err } @@ -261,11 +267,155 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee }) } +// ensureL4Port auto-assigns a listen port when needed and validates cluster support. +func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error { + if !service.IsL4Protocol(svc.Mode) { + return nil + } + customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster) + if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) { + if svc.Source != service.SourceEphemeral { + return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster) + } + svc.ListenPort = 0 + } + if svc.ListenPort == 0 { + port, err := m.assignPort(ctx, tx, svc.ProxyCluster) + if err != nil { + return err + } + svc.ListenPort = port + svc.PortAutoAssigned = true + } + return nil +} + +// checkPortConflict rejects L4 services that would conflict on the same listener. +// For TCP/UDP: unique per cluster+protocol+port. +// For TLS: unique per cluster+port+domain (SNI routing allows sharing ports). +// Cross-protocol conflicts (TLS vs raw TCP) are intentionally not checked: +// the proxy router multiplexes TLS (via SNI) and raw TCP (via fallback) on the same listener. +func (m *Manager) checkPortConflict(ctx context.Context, transaction store.Store, svc *service.Service) error { + if !service.IsL4Protocol(svc.Mode) || svc.ListenPort == 0 { + return nil + } + + existing, err := transaction.GetServicesByClusterAndPort(ctx, store.LockingStrengthUpdate, svc.ProxyCluster, svc.Mode, svc.ListenPort) + if err != nil { + return fmt.Errorf("query port conflicts: %w", err) + } + for _, s := range existing { + if s.ID == svc.ID { + continue + } + // TLS services on the same port are allowed if they have different domains (SNI routing) + if svc.Mode == service.ModeTLS && s.Domain != svc.Domain { + continue + } + return status.Errorf(status.AlreadyExists, + "%s port %d is already in use by service %q on cluster %s", + svc.Mode, svc.ListenPort, s.Name, svc.ProxyCluster) + } + + return nil +} + +// assignPort picks a random available port on the cluster within the auto-assign range. +func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string) (uint16, error) { + services, err := tx.GetServicesByCluster(ctx, store.LockingStrengthUpdate, cluster) + if err != nil { + return 0, fmt.Errorf("query cluster ports: %w", err) + } + + occupied := make(map[uint16]struct{}, len(services)) + for _, s := range services { + if s.ListenPort > 0 { + occupied[s.ListenPort] = struct{}{} + } + } + + portRange := int(autoAssignPortMax-autoAssignPortMin) + 1 + for range 100 { + port := autoAssignPortMin + uint16(rand.IntN(portRange)) + if _, taken := occupied[port]; !taken { + return port, nil + } + } + + for port := autoAssignPortMin; port <= autoAssignPortMax; port++ { + if _, taken := occupied[port]; !taken { + return port, nil + } + } + + return 0, status.Errorf(status.PreconditionFailed, "no available ports on cluster %s", cluster) +} + +// persistNewEphemeralService creates an ephemeral service inside a single transaction +// that also enforces the duplicate and per-peer limit checks atomically. +// The count and exists queries use FOR UPDATE locking to serialize concurrent creates +// for the same peer, preventing the per-peer limit from being bypassed. +func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil { + return err + } + + if err := m.ensureL4Port(ctx, transaction, svc); err != nil { + return err + } + + if err := m.checkPortConflict(ctx, transaction, svc); err != nil { + return err + } + + if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { + return err + } + + if err := transaction.CreateService(ctx, svc); err != nil { + return fmt.Errorf("create service: %w", err) + } + + return nil + }) +} + +func (m *Manager) validateEphemeralPreconditions(ctx context.Context, transaction store.Store, accountID, peerID string, svc *service.Service) error { + // Lock the peer row to serialize concurrent creates for the same peer. + if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil { + return fmt.Errorf("lock peer row: %w", err) + } + + exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain) + if err != nil { + return fmt.Errorf("check existing expose: %w", err) + } + if exists { + return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") + } + + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + return err + } + + count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID) + if err != nil { + return fmt.Errorf("count peer exposes: %w", err) + } + if count >= int64(maxExposesPerPeer) { + return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) + } + + return nil +} + +// checkDomainAvailable checks that no other service already uses this domain. func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error { existingService, err := transaction.GetServiceByDomain(ctx, domain) if err != nil { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { - return fmt.Errorf("failed to check existing service: %w", err) + return fmt.Errorf("check existing service: %w", err) } return nil } @@ -322,6 +472,10 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se return err } + if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil { + return err + } + updateInfo.oldCluster = existingService.ProxyCluster updateInfo.domainChanged = existingService.Domain != service.Domain @@ -335,12 +489,18 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se m.preserveExistingAuthSecrets(service, existingService) m.preserveServiceMetadata(service, existingService) + m.preserveListenPort(service, existingService) updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled + if err := m.ensureL4Port(ctx, transaction, service); err != nil { + return err + } + if err := m.checkPortConflict(ctx, transaction, service); err != nil { + return err + } if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { return err } - if err := transaction.UpdateService(ctx, service); err != nil { return fmt.Errorf("update service: %w", err) } @@ -351,23 +511,39 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se return &updateInfo, err } -func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error { - if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil { +func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error { + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil { return err } if m.clusterDeriver != nil { - newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) + newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain) if err != nil { - log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain) + log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain) } else { - service.ProxyCluster = newCluster + svc.ProxyCluster = newCluster } } return nil } +// validateProtocolChange rejects mode changes on update. +// Only empty<->HTTP is allowed; all other transitions are rejected. +func validateProtocolChange(oldMode, newMode string) error { + if newMode == "" || newMode == oldMode { + return nil + } + if isHTTPFamily(oldMode) && isHTTPFamily(newMode) { + return nil + } + return status.Errorf(status.InvalidArgument, "cannot change mode from %q to %q", oldMode, newMode) +} + +func isHTTPFamily(mode string) bool { + return mode == "" || mode == "http" +} + func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) { if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && @@ -388,11 +564,18 @@ func (m *Manager) preserveServiceMetadata(service, existingService *service.Serv service.SessionPublicKey = existingService.SessionPublicKey } +func (m *Manager) preserveListenPort(svc, existing *service.Service) { + if existing.ListenPort > 0 && svc.ListenPort == 0 { + svc.ListenPort = existing.ListenPort + svc.PortAutoAssigned = existing.PortAutoAssigned + } +} + func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) { oidcCfg := m.proxyController.GetOIDCValidationConfig() switch { - case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster: + case updateInfo.domainChanged || updateInfo.oldCluster != s.ProxyCluster: m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster) case !s.Enabled && updateInfo.serviceEnabledChanged: @@ -675,6 +858,10 @@ func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerI return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group") } +func (m *Manager) resolveDefaultDomain(serviceName string) (string, error) { + return m.buildRandomDomain(serviceName) +} + // CreateServiceFromPeer creates a service initiated by a peer expose request. // It validates the request, checks expose permissions, enforces the per-peer limit, // creates the service, and tracks it for TTL-based reaping. @@ -696,9 +883,9 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s svc.Source = service.SourceEphemeral if svc.Domain == "" { - domain, err := m.buildRandomDomain(svc.Name) + domain, err := m.resolveDefaultDomain(svc.Name) if err != nil { - return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err) + return nil, err } svc.Domain = domain } @@ -739,10 +926,16 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) + serviceURL := "https://" + svc.Domain + if service.IsL4Protocol(svc.Mode) { + serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort) + } + return &service.ExposeServiceResponse{ - ServiceName: svc.Name, - ServiceURL: "https://" + svc.Domain, - Domain: svc.Domain, + ServiceName: svc.Name, + ServiceURL: serviceURL, + Domain: svc.Domain, + PortAutoAssigned: svc.PortAutoAssigned, }, nil } @@ -761,64 +954,47 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr return groupIDs, nil } -func (m *Manager) buildRandomDomain(name string) (string, error) { +func (m *Manager) getDefaultClusterDomain() (string, error) { if m.clusterDeriver == nil { - return "", fmt.Errorf("unable to get random domain") + return "", fmt.Errorf("unable to get cluster domain") } clusterDomains := m.clusterDeriver.GetClusterDomains() if len(clusterDomains) == 0 { - return "", fmt.Errorf("no cluster domains found for service %s", name) + return "", fmt.Errorf("no cluster domains available") } - index := rand.IntN(len(clusterDomains)) - domain := name + "." + clusterDomains[index] - return domain, nil + return clusterDomains[rand.IntN(len(clusterDomains))], nil +} + +func (m *Manager) buildRandomDomain(name string) (string, error) { + domain, err := m.getDefaultClusterDomain() + if err != nil { + return "", err + } + return name + "." + domain, nil } // RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service. -func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { - return m.store.RenewEphemeralService(ctx, accountID, peerID, domain) +func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { + return m.store.RenewEphemeralService(ctx, accountID, peerID, serviceID) } // StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB. -func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { - if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil { - log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err) +func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { + if err := m.deleteServiceFromPeer(ctx, accountID, peerID, serviceID, false); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer-exposed service %s: %v", serviceID, err) return err } return nil } -// deleteServiceFromPeer deletes a peer-initiated service identified by domain. +// deleteServiceFromPeer deletes a peer-initiated service identified by service ID. // When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed. -func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error { - svc, err := m.lookupPeerService(ctx, accountID, peerID, domain) - if err != nil { - return err - } - +func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string, expired bool) error { activityCode := activity.PeerServiceUnexposed if expired { activityCode = activity.PeerServiceExposeExpired } - return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode) -} - -// lookupPeerService finds a peer-initiated service by domain and validates ownership. -func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) { - svc, err := m.store.GetServiceByDomain(ctx, domain) - if err != nil { - return nil, err - } - - if svc.Source != service.SourceEphemeral { - return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose") - } - - if svc.SourcePeer != peerID { - return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer") - } - - return svc, nil + return m.deletePeerService(ctx, accountID, peerID, serviceID, activityCode) } func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error { diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index ba4e1c805..d23c91017 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -803,8 +803,8 @@ func TestCreateServiceFromPeer(t *testing.T) { mgr, testStore := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -826,9 +826,9 @@ func TestCreateServiceFromPeer(t *testing.T) { mgr, _ := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 80, - Protocol: "http", - Domain: "example.com", + Port: 80, + Mode: "http", + Domain: "example.com", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -847,8 +847,8 @@ func TestCreateServiceFromPeer(t *testing.T) { require.NoError(t, err) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -860,8 +860,8 @@ func TestCreateServiceFromPeer(t *testing.T) { mgr, _ := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 0, - Protocol: "http", + Port: 0, + Mode: "http", } _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -878,62 +878,52 @@ func TestExposeServiceRequestValidate(t *testing.T) { }{ { name: "valid http request", - req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"}, + req: rpservice.ExposeServiceRequest{Port: 8080, Mode: "http"}, wantErr: "", }, { - name: "valid https request with pin", - req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"}, - wantErr: "", + name: "https mode rejected", + req: rpservice.ExposeServiceRequest{Port: 443, Mode: "https", Pin: "123456"}, + wantErr: "unsupported mode", }, { name: "port zero rejected", - req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"}, + req: rpservice.ExposeServiceRequest{Port: 0, Mode: "http"}, wantErr: "port must be between 1 and 65535", }, { - name: "negative port rejected", - req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"}, - wantErr: "port must be between 1 and 65535", - }, - { - name: "port above 65535 rejected", - req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"}, - wantErr: "port must be between 1 and 65535", - }, - { - name: "unsupported protocol", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"}, - wantErr: "unsupported protocol", + name: "unsupported mode", + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "ftp"}, + wantErr: "unsupported mode", }, { name: "invalid pin format", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "abc"}, wantErr: "invalid pin", }, { name: "pin too short", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "12345"}, wantErr: "invalid pin", }, { name: "valid 6-digit pin", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "000000"}, wantErr: "", }, { name: "empty user group name", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", UserGroups: []string{"valid", ""}}, wantErr: "user group name cannot be empty", }, { name: "invalid name prefix", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "INVALID"}, wantErr: "invalid name prefix", }, { name: "valid name prefix", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "my-service"}, wantErr: "", }, } @@ -966,14 +956,14 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { // First create a service req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - // Delete by domain using unexported method - err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, false) require.NoError(t, err) // Verify service is deleted @@ -982,16 +972,17 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { }) t.Run("expire uses correct activity", func(t *testing.T) { - mgr, _ := setupIntegrationTest(t) + mgr, testStore := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, true) require.NoError(t, err) }) } @@ -1003,13 +994,14 @@ func TestStopServiceFromPeer(t *testing.T) { mgr, testStore := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) require.NoError(t, err) _, err = testStore.GetServiceByDomain(ctx, resp.Domain) @@ -1022,8 +1014,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) { mgr, testStore := setupIntegrationTest(t) resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) @@ -1042,8 +1034,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) { assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete") _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 9090, - Protocol: "http", + Port: 9090, + Mode: "http", }) assert.NoError(t, err, "new expose should succeed after API delete") } @@ -1054,8 +1046,8 @@ func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) { for i := range 3 { _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080 + i, - Protocol: "http", + Port: uint16(8080 + i), + Mode: "http", }) require.NoError(t, err) } @@ -1076,21 +1068,22 @@ func TestRenewServiceFromPeer(t *testing.T) { ctx := context.Background() t.Run("renews tracked expose", func(t *testing.T) { - mgr, _ := setupIntegrationTest(t) + mgr, testStore := setupIntegrationTest(t) resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) - err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) require.NoError(t, err) }) t.Run("fails for untracked domain", func(t *testing.T) { mgr, _ := setupIntegrationTest(t) - err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com") + err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id") require.Error(t, err) }) } @@ -1191,3 +1184,33 @@ func TestDeleteService_DeletesTargets(t *testing.T) { require.NoError(t, err) assert.Len(t, targets, 0, "All targets should be deleted when service is deleted") } + +func TestValidateProtocolChange(t *testing.T) { + tests := []struct { + name string + oldP string + newP string + wantErr bool + }{ + {"empty to http", "", "http", false}, + {"http to http", "http", "http", false}, + {"same protocol", "tcp", "tcp", false}, + {"empty new proto", "tcp", "", false}, + {"http to tcp", "http", "tcp", true}, + {"tcp to udp", "tcp", "udp", true}, + {"tls to http", "tls", "http", true}, + {"udp to tls", "udp", "tls", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateProtocolChange(tt.oldP, tt.newP) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot change mode") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index bfad7fe9a..623284404 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -34,6 +34,7 @@ const ( ) type Status string +type TargetType string const ( StatusPending Status = "pending" @@ -43,34 +44,36 @@ const ( StatusCertificateFailed Status = "certificate_failed" StatusError Status = "error" - TargetTypePeer = "peer" - TargetTypeHost = "host" - TargetTypeDomain = "domain" - TargetTypeSubnet = "subnet" + TargetTypePeer TargetType = "peer" + TargetTypeHost TargetType = "host" + TargetTypeDomain TargetType = "domain" + TargetTypeSubnet TargetType = "subnet" SourcePermanent = "permanent" SourceEphemeral = "ephemeral" ) type TargetOptions struct { - SkipTLSVerify bool `json:"skip_tls_verify"` - RequestTimeout time.Duration `json:"request_timeout,omitempty"` - PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"` - CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"` + SkipTLSVerify bool `json:"skip_tls_verify"` + RequestTimeout time.Duration `json:"request_timeout,omitempty"` + SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"` + PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"` + CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"` } type Target struct { - ID uint `gorm:"primaryKey" json:"-"` - AccountID string `gorm:"index:idx_target_account;not null" json:"-"` - ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"` - Path *string `json:"path,omitempty"` - Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored - Port int `gorm:"index:idx_target_port" json:"port"` - Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` - TargetId string `gorm:"index:idx_target_id" json:"target_id"` - TargetType string `gorm:"index:idx_target_type" json:"target_type"` - Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"` - Options TargetOptions `gorm:"embedded" json:"options"` + ID uint `gorm:"primaryKey" json:"-"` + AccountID string `gorm:"index:idx_target_account;not null" json:"-"` + ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"` + Path *string `json:"path,omitempty"` + Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored + Port uint16 `gorm:"index:idx_target_port" json:"port"` + Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` + TargetId string `gorm:"index:idx_target_id" json:"target_id"` + TargetType TargetType `gorm:"index:idx_target_type" json:"target_type"` + Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"` + Options TargetOptions `gorm:"embedded" json:"options"` + ProxyProtocol bool `json:"proxy_protocol"` } type PasswordAuthConfig struct { @@ -146,23 +149,10 @@ type Service struct { SessionPublicKey string `gorm:"column:session_public_key"` Source string `gorm:"default:'permanent';index:idx_service_source_peer"` SourcePeer string `gorm:"index:idx_service_source_peer"` -} - -func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service { - for _, target := range targets { - target.AccountID = accountID - } - - s := &Service{ - AccountID: accountID, - Name: name, - Domain: domain, - ProxyCluster: proxyCluster, - Targets: targets, - Enabled: enabled, - } - s.InitNewRecord() - return s + // Mode determines the service type: "http", "tcp", "udp", or "tls". + Mode string `gorm:"default:'http'"` + ListenPort uint16 + PortAutoAssigned bool } // InitNewRecord generates a new unique ID and resets metadata for a newly created @@ -177,21 +167,17 @@ func (s *Service) InitNewRecord() { } func (s *Service) ToAPIResponse() *api.Service { - s.Auth.ClearSecrets() - authConfig := api.ServiceAuthConfig{} if s.Auth.PasswordAuth != nil { authConfig.PasswordAuth = &api.PasswordAuthConfig{ - Enabled: s.Auth.PasswordAuth.Enabled, - Password: s.Auth.PasswordAuth.Password, + Enabled: s.Auth.PasswordAuth.Enabled, } } if s.Auth.PinAuth != nil { authConfig.PinAuth = &api.PINAuthConfig{ Enabled: s.Auth.PinAuth.Enabled, - Pin: s.Auth.PinAuth.Pin, } } @@ -208,13 +194,18 @@ func (s *Service) ToAPIResponse() *api.Service { st := api.ServiceTarget{ Path: target.Path, Host: &target.Host, - Port: target.Port, + Port: int(target.Port), Protocol: api.ServiceTargetProtocol(target.Protocol), TargetId: target.TargetId, TargetType: api.ServiceTargetTargetType(target.TargetType), Enabled: target.Enabled, } - st.Options = targetOptionsToAPI(target.Options) + opts := targetOptionsToAPI(target.Options) + if opts == nil { + opts = &api.ServiceTargetOptions{} + } + opts.ProxyProtocol = &target.ProxyProtocol + st.Options = opts apiTargets = append(apiTargets, st) } @@ -227,6 +218,9 @@ func (s *Service) ToAPIResponse() *api.Service { meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt } + mode := api.ServiceMode(s.Mode) + listenPort := int(s.ListenPort) + resp := &api.Service{ Id: s.ID, Name: s.Name, @@ -237,6 +231,9 @@ func (s *Service) ToAPIResponse() *api.Service { RewriteRedirects: &s.RewriteRedirects, Auth: authConfig, Meta: meta, + Mode: &mode, + ListenPort: &listenPort, + PortAutoAssigned: &s.PortAutoAssigned, } if s.ProxyCluster != "" { @@ -247,37 +244,7 @@ func (s *Service) ToAPIResponse() *api.Service { } func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping { - pathMappings := make([]*proto.PathMapping, 0, len(s.Targets)) - for _, target := range s.Targets { - if !target.Enabled { - continue - } - - // TODO: Make path prefix stripping configurable per-target. - // Currently the matching prefix is baked into the target URL path, - // so the proxy strips-then-re-adds it (effectively a no-op). - targetURL := url.URL{ - Scheme: target.Protocol, - Host: target.Host, - Path: "/", // TODO: support service path - } - if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { - targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port)) - } - - path := "/" - if target.Path != nil { - path = *target.Path - } - - pm := &proto.PathMapping{ - Path: path, - Target: targetURL.String(), - } - - pm.Options = targetOptionsToProto(target.Options) - pathMappings = append(pathMappings, pm) - } + pathMappings := s.buildPathMappings() auth := &proto.Authentication{ SessionKey: s.SessionPublicKey, @@ -306,9 +273,58 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf AccountId: s.AccountID, PassHostHeader: s.PassHostHeader, RewriteRedirects: s.RewriteRedirects, + Mode: s.Mode, + ListenPort: int32(s.ListenPort), //nolint:gosec } } +// buildPathMappings constructs PathMapping entries from targets. +// For HTTP/HTTPS, each target becomes a path-based route with a full URL. +// For L4/TLS, a single target maps to a host:port address. +func (s *Service) buildPathMappings() []*proto.PathMapping { + pathMappings := make([]*proto.PathMapping, 0, len(s.Targets)) + for _, target := range s.Targets { + if !target.Enabled { + continue + } + + if IsL4Protocol(s.Mode) { + pm := &proto.PathMapping{ + Target: net.JoinHostPort(target.Host, strconv.FormatUint(uint64(target.Port), 10)), + } + opts := l4TargetOptionsToProto(target) + if opts != nil { + pm.Options = opts + } + pathMappings = append(pathMappings, pm) + continue + } + + // HTTP/HTTPS: build full URL + targetURL := url.URL{ + Scheme: target.Protocol, + Host: target.Host, + Path: "/", + } + if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { + targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10)) + } + + path := "/" + if target.Path != nil { + path = *target.Path + } + + pm := &proto.PathMapping{ + Path: path, + Target: targetURL.String(), + } + pm.Options = targetOptionsToProto(target.Options) + pathMappings = append(pathMappings, pm) + } + return pathMappings +} + func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { switch op { case Create: @@ -325,8 +341,8 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { // isDefaultPort reports whether port is the standard default for the given scheme // (443 for https, 80 for http). -func isDefaultPort(scheme string, port int) bool { - return (scheme == "https" && port == 443) || (scheme == "http" && port == 80) +func isDefaultPort(scheme string, port uint16) bool { + return (scheme == TargetProtoHTTPS && port == 443) || (scheme == TargetProtoHTTP && port == 80) } // PathRewriteMode controls how the request path is rewritten before forwarding. @@ -346,7 +362,7 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode { } func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { - if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 { + if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 { return nil } apiOpts := &api.ServiceTargetOptions{} @@ -357,6 +373,10 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { s := opts.RequestTimeout.String() apiOpts.RequestTimeout = &s } + if opts.SessionIdleTimeout != 0 { + s := opts.SessionIdleTimeout.String() + apiOpts.SessionIdleTimeout = &s + } if opts.PathRewrite != "" { pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite) apiOpts.PathRewrite = &pr @@ -382,6 +402,23 @@ func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions { return popts } +// l4TargetOptionsToProto converts L4-relevant target options to proto. +func l4TargetOptionsToProto(target *Target) *proto.PathTargetOptions { + if !target.ProxyProtocol && target.Options.RequestTimeout == 0 && target.Options.SessionIdleTimeout == 0 { + return nil + } + opts := &proto.PathTargetOptions{ + ProxyProtocol: target.ProxyProtocol, + } + if target.Options.RequestTimeout > 0 { + opts.RequestTimeout = durationpb.New(target.Options.RequestTimeout) + } + if target.Options.SessionIdleTimeout > 0 { + opts.SessionIdleTimeout = durationpb.New(target.Options.SessionIdleTimeout) + } + return opts +} + func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) { var opts TargetOptions if o.SkipTlsVerify != nil { @@ -394,6 +431,13 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, } opts.RequestTimeout = d } + if o.SessionIdleTimeout != nil { + d, err := time.ParseDuration(*o.SessionIdleTimeout) + if err != nil { + return opts, fmt.Errorf("target %d: parse session_idle_timeout %q: %w", idx, *o.SessionIdleTimeout, err) + } + opts.SessionIdleTimeout = d + } if o.PathRewrite != nil { opts.PathRewrite = PathRewriteMode(*o.PathRewrite) } @@ -408,15 +452,49 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro s.Domain = req.Domain s.AccountID = accountID - targets := make([]*Target, 0, len(req.Targets)) - for i, apiTarget := range req.Targets { + if req.Mode != nil { + s.Mode = string(*req.Mode) + } + if req.ListenPort != nil { + s.ListenPort = uint16(*req.ListenPort) //nolint:gosec + } + + targets, err := targetsFromAPI(accountID, req.Targets) + if err != nil { + return err + } + s.Targets = targets + s.Enabled = req.Enabled + + if req.PassHostHeader != nil { + s.PassHostHeader = *req.PassHostHeader + } + if req.RewriteRedirects != nil { + s.RewriteRedirects = *req.RewriteRedirects + } + + if req.Auth != nil { + s.Auth = authFromAPI(req.Auth) + } + + return nil +} + +func targetsFromAPI(accountID string, apiTargetsPtr *[]api.ServiceTarget) ([]*Target, error) { + var apiTargets []api.ServiceTarget + if apiTargetsPtr != nil { + apiTargets = *apiTargetsPtr + } + + targets := make([]*Target, 0, len(apiTargets)) + for i, apiTarget := range apiTargets { target := &Target{ AccountID: accountID, Path: apiTarget.Path, - Port: apiTarget.Port, + Port: uint16(apiTarget.Port), //nolint:gosec // validated by API layer Protocol: string(apiTarget.Protocol), TargetId: apiTarget.TargetId, - TargetType: string(apiTarget.TargetType), + TargetType: TargetType(apiTarget.TargetType), Enabled: apiTarget.Enabled, } if apiTarget.Host != nil { @@ -425,49 +503,42 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro if apiTarget.Options != nil { opts, err := targetOptionsFromAPI(i, apiTarget.Options) if err != nil { - return err + return nil, err } target.Options = opts + if apiTarget.Options.ProxyProtocol != nil { + target.ProxyProtocol = *apiTarget.Options.ProxyProtocol + } } targets = append(targets, target) } - s.Targets = targets + return targets, nil +} - s.Enabled = req.Enabled - - if req.PassHostHeader != nil { - s.PassHostHeader = *req.PassHostHeader - } - - if req.RewriteRedirects != nil { - s.RewriteRedirects = *req.RewriteRedirects - } - - if req.Auth.PasswordAuth != nil { - s.Auth.PasswordAuth = &PasswordAuthConfig{ - Enabled: req.Auth.PasswordAuth.Enabled, - Password: req.Auth.PasswordAuth.Password, +func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig { + var auth AuthConfig + if reqAuth.PasswordAuth != nil { + auth.PasswordAuth = &PasswordAuthConfig{ + Enabled: reqAuth.PasswordAuth.Enabled, + Password: reqAuth.PasswordAuth.Password, } } - - if req.Auth.PinAuth != nil { - s.Auth.PinAuth = &PINAuthConfig{ - Enabled: req.Auth.PinAuth.Enabled, - Pin: req.Auth.PinAuth.Pin, + if reqAuth.PinAuth != nil { + auth.PinAuth = &PINAuthConfig{ + Enabled: reqAuth.PinAuth.Enabled, + Pin: reqAuth.PinAuth.Pin, } } - - if req.Auth.BearerAuth != nil { + if reqAuth.BearerAuth != nil { bearerAuth := &BearerAuthConfig{ - Enabled: req.Auth.BearerAuth.Enabled, + Enabled: reqAuth.BearerAuth.Enabled, } - if req.Auth.BearerAuth.DistributionGroups != nil { - bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups + if reqAuth.BearerAuth.DistributionGroups != nil { + bearerAuth.DistributionGroups = *reqAuth.BearerAuth.DistributionGroups } - s.Auth.BearerAuth = bearerAuth + auth.BearerAuth = bearerAuth } - - return nil + return auth } func (s *Service) Validate() error { @@ -478,14 +549,69 @@ func (s *Service) Validate() error { return errors.New("service name exceeds maximum length of 255 characters") } - if s.Domain == "" { - return errors.New("service domain is required") - } - if len(s.Targets) == 0 { return errors.New("at least one target is required") } + if s.Mode == "" { + s.Mode = ModeHTTP + } + + switch s.Mode { + case ModeHTTP: + return s.validateHTTPMode() + case ModeTCP, ModeUDP: + return s.validateTCPUDPMode() + case ModeTLS: + return s.validateTLSMode() + default: + return fmt.Errorf("unsupported mode %q", s.Mode) + } +} + +func (s *Service) validateHTTPMode() error { + if s.Domain == "" { + return errors.New("service domain is required") + } + if s.ListenPort != 0 { + return errors.New("listen_port is not supported for HTTP services") + } + return s.validateHTTPTargets() +} + +func (s *Service) validateTCPUDPMode() error { + if s.Domain == "" { + return errors.New("domain is required for TCP/UDP services (used for cluster derivation)") + } + if s.isAuthEnabled() { + return errors.New("auth is not supported for TCP/UDP services") + } + if len(s.Targets) != 1 { + return errors.New("TCP/UDP services must have exactly one target") + } + if s.Mode == ModeUDP && s.Targets[0].ProxyProtocol { + return errors.New("proxy_protocol is not supported for UDP services") + } + return s.validateL4Target(s.Targets[0]) +} + +func (s *Service) validateTLSMode() error { + if s.Domain == "" { + return errors.New("domain is required for TLS services (used for SNI matching)") + } + if s.isAuthEnabled() { + return errors.New("auth is not supported for TLS services") + } + if s.ListenPort == 0 { + return errors.New("listen_port is required for TLS services") + } + if len(s.Targets) != 1 { + return errors.New("TLS services must have exactly one target") + } + return s.validateL4Target(s.Targets[0]) +} + +func (s *Service) validateHTTPTargets() error { for i, target := range s.Targets { switch target.TargetType { case TargetTypePeer, TargetTypeHost, TargetTypeDomain: @@ -500,6 +626,9 @@ func (s *Service) Validate() error { if target.TargetId == "" { return fmt.Errorf("target %d has empty target_id", i) } + if target.ProxyProtocol { + return fmt.Errorf("target %d: proxy_protocol is not supported for HTTP services", i) + } if err := validateTargetOptions(i, &target.Options); err != nil { return err } @@ -508,11 +637,62 @@ func (s *Service) Validate() error { return nil } +func (s *Service) validateL4Target(target *Target) error { + if target.Port == 0 { + return errors.New("target port is required for L4 services") + } + if target.TargetId == "" { + return errors.New("target_id is required for L4 services") + } + switch target.TargetType { + case TargetTypePeer, TargetTypeHost: + // OK + case TargetTypeSubnet: + if target.Host == "" { + return errors.New("target host is required for subnet targets") + } + default: + return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType) + } + if target.Path != nil && *target.Path != "" && *target.Path != "/" { + return errors.New("path is not supported for L4 services") + } + return nil +} + +// Service mode constants. const ( - maxRequestTimeout = 5 * time.Minute - maxCustomHeaders = 16 - maxHeaderKeyLen = 128 - maxHeaderValueLen = 4096 + ModeHTTP = "http" + ModeTCP = "tcp" + ModeUDP = "udp" + ModeTLS = "tls" +) + +// Target protocol constants (URL scheme for backend connections). +const ( + TargetProtoHTTP = "http" + TargetProtoHTTPS = "https" + TargetProtoTCP = "tcp" + TargetProtoUDP = "udp" +) + +// IsL4Protocol returns true if the mode requires port-based routing (TCP, UDP, or TLS). +func IsL4Protocol(mode string) bool { + return mode == ModeTCP || mode == ModeUDP || mode == ModeTLS +} + +// IsPortBasedProtocol returns true if the mode relies on dedicated port allocation. +// TLS is excluded because it uses SNI routing and can share ports with other TLS services. +func IsPortBasedProtocol(mode string) bool { + return mode == ModeTCP || mode == ModeUDP +} + +const ( + maxRequestTimeout = 5 * time.Minute + maxSessionIdleTimeout = 10 * time.Minute + maxCustomHeaders = 16 + maxHeaderKeyLen = 128 + maxHeaderValueLen = 4096 ) // httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition. @@ -560,6 +740,15 @@ func validateTargetOptions(idx int, opts *TargetOptions) error { } } + if opts.SessionIdleTimeout != 0 { + if opts.SessionIdleTimeout <= 0 { + return fmt.Errorf("target %d: session_idle_timeout must be positive", idx) + } + if opts.SessionIdleTimeout > maxSessionIdleTimeout { + return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout) + } + } + if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil { return err } @@ -608,17 +797,49 @@ func containsCRLF(s string) bool { } func (s *Service) EventMeta() map[string]any { - return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()} + meta := map[string]any{ + "name": s.Name, + "domain": s.Domain, + "proxy_cluster": s.ProxyCluster, + "source": s.Source, + "auth": s.isAuthEnabled(), + "mode": s.Mode, + } + + if s.ListenPort != 0 { + meta["listen_port"] = s.ListenPort + } + + if len(s.Targets) > 0 { + t := s.Targets[0] + if t.ProxyProtocol { + meta["proxy_protocol"] = true + } + if t.Options.RequestTimeout != 0 { + meta["request_timeout"] = t.Options.RequestTimeout.String() + } + if t.Options.SessionIdleTimeout != 0 { + meta["session_idle_timeout"] = t.Options.SessionIdleTimeout.String() + } + } + + return meta } func (s *Service) isAuthEnabled() bool { - return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil + return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) || + (s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) || + (s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) } func (s *Service) Copy() *Service { targets := make([]*Target, len(s.Targets)) for i, target := range s.Targets { targetCopy := *target + if target.Path != nil { + p := *target.Path + targetCopy.Path = &p + } if len(target.Options.CustomHeaders) > 0 { targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders)) for k, v := range target.Options.CustomHeaders { @@ -628,6 +849,24 @@ func (s *Service) Copy() *Service { targets[i] = &targetCopy } + authCopy := s.Auth + if s.Auth.PasswordAuth != nil { + pa := *s.Auth.PasswordAuth + authCopy.PasswordAuth = &pa + } + if s.Auth.PinAuth != nil { + pa := *s.Auth.PinAuth + authCopy.PinAuth = &pa + } + if s.Auth.BearerAuth != nil { + ba := *s.Auth.BearerAuth + if len(s.Auth.BearerAuth.DistributionGroups) > 0 { + ba.DistributionGroups = make([]string, len(s.Auth.BearerAuth.DistributionGroups)) + copy(ba.DistributionGroups, s.Auth.BearerAuth.DistributionGroups) + } + authCopy.BearerAuth = &ba + } + return &Service{ ID: s.ID, AccountID: s.AccountID, @@ -638,12 +877,15 @@ func (s *Service) Copy() *Service { Enabled: s.Enabled, PassHostHeader: s.PassHostHeader, RewriteRedirects: s.RewriteRedirects, - Auth: s.Auth, + Auth: authCopy, Meta: s.Meta, SessionPrivateKey: s.SessionPrivateKey, SessionPublicKey: s.SessionPublicKey, Source: s.Source, SourcePeer: s.SourcePeer, + Mode: s.Mode, + ListenPort: s.ListenPort, + PortAutoAssigned: s.PortAutoAssigned, } } @@ -688,12 +930,16 @@ var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`) // ExposeServiceRequest contains the parameters for creating a peer-initiated expose service. type ExposeServiceRequest struct { NamePrefix string - Port int - Protocol string - Domain string - Pin string - Password string - UserGroups []string + Port uint16 + Mode string + // TargetProtocol is the protocol used to connect to the peer backend. + // For HTTP mode: "http" (default) or "https". For L4 modes: "tcp" or "udp". + TargetProtocol string + Domain string + Pin string + Password string + UserGroups []string + ListenPort uint16 } // Validate checks all fields of the expose request. @@ -702,12 +948,20 @@ func (r *ExposeServiceRequest) Validate() error { return errors.New("request cannot be nil") } - if r.Port < 1 || r.Port > 65535 { + if r.Port == 0 { return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port) } - if r.Protocol != "http" && r.Protocol != "https" { - return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol) + switch r.Mode { + case ModeHTTP, ModeTCP, ModeUDP, ModeTLS: + default: + return fmt.Errorf("unsupported mode %q", r.Mode) + } + + if IsL4Protocol(r.Mode) { + if r.Pin != "" || r.Password != "" || len(r.UserGroups) > 0 { + return fmt.Errorf("authentication is not supported for %s mode", r.Mode) + } } if r.Pin != "" && !pinRegexp.MatchString(r.Pin) { @@ -729,55 +983,79 @@ func (r *ExposeServiceRequest) Validate() error { // ToService builds a Service from the expose request. func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service { - service := &Service{ + svc := &Service{ AccountID: accountID, Name: serviceName, + Mode: r.Mode, Enabled: true, - Targets: []*Target{ - { - AccountID: accountID, - Port: r.Port, - Protocol: r.Protocol, - TargetId: peerID, - TargetType: TargetTypePeer, - Enabled: true, - }, + } + + // If domain is empty, CreateServiceFromPeer generates a unique subdomain. + // When explicitly provided, the service name is prepended as a subdomain. + if r.Domain != "" { + svc.Domain = serviceName + "." + r.Domain + } + + if IsL4Protocol(r.Mode) { + svc.ListenPort = r.Port + if r.ListenPort > 0 { + svc.ListenPort = r.ListenPort + } + } + + var targetProto string + switch { + case !IsL4Protocol(r.Mode): + targetProto = TargetProtoHTTP + if r.TargetProtocol != "" { + targetProto = r.TargetProtocol + } + case r.Mode == ModeUDP: + targetProto = TargetProtoUDP + default: + targetProto = TargetProtoTCP + } + svc.Targets = []*Target{ + { + AccountID: accountID, + Port: r.Port, + Protocol: targetProto, + TargetId: peerID, + TargetType: TargetTypePeer, + Enabled: true, }, } - if r.Domain != "" { - service.Domain = serviceName + "." + r.Domain - } - if r.Pin != "" { - service.Auth.PinAuth = &PINAuthConfig{ + svc.Auth.PinAuth = &PINAuthConfig{ Enabled: true, Pin: r.Pin, } } if r.Password != "" { - service.Auth.PasswordAuth = &PasswordAuthConfig{ + svc.Auth.PasswordAuth = &PasswordAuthConfig{ Enabled: true, Password: r.Password, } } if len(r.UserGroups) > 0 { - service.Auth.BearerAuth = &BearerAuthConfig{ + svc.Auth.BearerAuth = &BearerAuthConfig{ Enabled: true, DistributionGroups: r.UserGroups, } } - return service + return svc } // ExposeServiceResponse contains the result of a successful peer expose creation. type ExposeServiceResponse struct { - ServiceName string - ServiceURL string - Domain string + ServiceName string + ServiceURL string + Domain string + PortAutoAssigned bool } // GenerateExposeName generates a random service name for peer-exposed services. diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index 79c98fc14..a8a8ae5d6 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -44,7 +44,7 @@ func TestValidate_EmptyDomain(t *testing.T) { func TestValidate_NoTargets(t *testing.T) { rp := validProxy() rp.Targets = nil - assert.ErrorContains(t, rp.Validate(), "at least one target") + assert.ErrorContains(t, rp.Validate(), "at least one target is required") } func TestValidate_EmptyTargetId(t *testing.T) { @@ -273,7 +273,7 @@ func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) { func TestIsDefaultPort(t *testing.T) { tests := []struct { scheme string - port int + port uint16 want bool }{ {"http", 80, true}, @@ -299,7 +299,7 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) { name string protocol string host string - port int + port uint16 wantTarget string }{ { @@ -645,8 +645,8 @@ func TestGenerateExposeName(t *testing.T) { func TestExposeServiceRequest_ToService(t *testing.T) { t.Run("basic HTTP service", func(t *testing.T) { req := &ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } service := req.ToService("account-1", "peer-1", "mysvc") @@ -658,7 +658,7 @@ func TestExposeServiceRequest_ToService(t *testing.T) { require.Len(t, service.Targets, 1) target := service.Targets[0] - assert.Equal(t, 8080, target.Port) + assert.Equal(t, uint16(8080), target.Port) assert.Equal(t, "http", target.Protocol) assert.Equal(t, "peer-1", target.TargetId) assert.Equal(t, TargetTypePeer, target.TargetType) @@ -730,3 +730,182 @@ func TestExposeServiceRequest_ToService(t *testing.T) { require.NotNil(t, service.Auth.BearerAuth) }) } + +func TestValidate_TLSOnly(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_TLSMissingListenPort(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 0, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "listen_port is required") +} + +func TestValidate_TLSMissingDomain(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "domain is required") +} + +func TestValidate_TCPValid(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_TCPMissingListenPort(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + require.NoError(t, rp.Validate(), "TCP with listen_port=0 is valid (auto-assigned by manager)") +} + +func TestValidate_L4MultipleTargets(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + {TargetId: "peer-2", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "exactly one target") +} + +func TestValidate_L4TargetMissingPort(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 0, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "port is required") +} + +func TestValidate_TLSInvalidTargetType(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: "invalid", Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.Error(t, rp.Validate()) +} + +func TestValidate_TLSSubnetValid(t *testing.T) { + rp := &Service{ + Name: "tls-subnet", + Mode: "tls", + Domain: "example.com", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "subnet-1", TargetType: TargetTypeSubnet, Protocol: "tcp", Port: 443, Host: "10.0.0.5", Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_HTTPProxyProtocolRejected(t *testing.T) { + rp := validProxy() + rp.Targets[0].ProxyProtocol = true + assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for HTTP") +} + +func TestValidate_UDPProxyProtocolRejected(t *testing.T) { + rp := &Service{ + Name: "udp-svc", + Mode: "udp", + Domain: "cluster.test", + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "udp", Port: 5432, Enabled: true, ProxyProtocol: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for UDP") +} + +func TestValidate_TCPProxyProtocolAllowed(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true, ProxyProtocol: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestExposeServiceRequest_Validate_L4RejectsAuth(t *testing.T) { + tests := []struct { + name string + req ExposeServiceRequest + }{ + { + name: "tcp with pin", + req: ExposeServiceRequest{Port: 8080, Mode: "tcp", Pin: "123456"}, + }, + { + name: "udp with password", + req: ExposeServiceRequest{Port: 8080, Mode: "udp", Password: "secret"}, + }, + { + name: "tls with user groups", + req: ExposeServiceRequest{Port: 443, Mode: "tls", UserGroups: []string{"admins"}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.req.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "authentication is not supported") + }) + } +} + +func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) { + req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"} + require.NoError(t, req.Validate()) +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index eb13a15e3..88d37ca80 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 29a8953ac..a32cf6046 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" @@ -211,6 +212,9 @@ func (s *BaseServer) ProxyManager() proxy.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return Create(s, func() *manager.Manager { m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager()) + s.AfterInit(func(s *BaseServer) { + m.SetClusterCapabilities(s.ServiceProxyController()) + }) return &m }) } diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index c74fa2660..ef417d3cf 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -107,7 +107,8 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, LazyConnectionEnabled: settings.LazyConnectionEnabled, AutoUpdate: &proto.AutoUpdateSettings{ - Version: settings.AutoUpdateVersion, + Version: settings.AutoUpdateVersion, + AlwaysUpdate: settings.AutoUpdateAlways, }, } } diff --git a/management/internals/shared/grpc/expose_service.go b/management/internals/shared/grpc/expose_service.go index c444471b0..1b87f7ede 100644 --- a/management/internals/shared/grpc/expose_service.go +++ b/management/internals/shared/grpc/expose_service.go @@ -2,6 +2,7 @@ package grpc import ( "context" + "fmt" pb "github.com/golang/protobuf/proto" // nolint log "github.com/sirupsen/logrus" @@ -39,23 +40,38 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } + if exposeReq.Port > 65535 { + return nil, status.Errorf(codes.InvalidArgument, "port out of range: %d", exposeReq.Port) + } + if exposeReq.ListenPort > 65535 { + return nil, status.Errorf(codes.InvalidArgument, "listen_port out of range: %d", exposeReq.ListenPort) + } + + mode, err := exposeProtocolToString(exposeReq.Protocol) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{ - NamePrefix: exposeReq.NamePrefix, - Port: int(exposeReq.Port), - Protocol: exposeProtocolToString(exposeReq.Protocol), - Domain: exposeReq.Domain, - Pin: exposeReq.Pin, - Password: exposeReq.Password, - UserGroups: exposeReq.UserGroups, + NamePrefix: exposeReq.NamePrefix, + Port: uint16(exposeReq.Port), //nolint:gosec // validated above + Mode: mode, + TargetProtocol: exposeTargetProtocol(exposeReq.Protocol), + Domain: exposeReq.Domain, + Pin: exposeReq.Pin, + Password: exposeReq.Password, + UserGroups: exposeReq.UserGroups, + ListenPort: uint16(exposeReq.ListenPort), //nolint:gosec // validated above }) if err != nil { return nil, mapExposeError(ctx, err) } return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{ - ServiceName: created.ServiceName, - ServiceUrl: created.ServiceURL, - Domain: created.Domain, + ServiceName: created.ServiceName, + ServiceUrl: created.ServiceURL, + Domain: created.Domain, + PortAutoAssigned: created.PortAutoAssigned, }) } @@ -77,7 +93,12 @@ func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) ( return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } - if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil { + serviceID, err := s.resolveServiceID(ctx, renewReq.Domain) + if err != nil { + return nil, mapExposeError(ctx, err) + } + + if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil { return nil, mapExposeError(ctx, err) } @@ -102,7 +123,12 @@ func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (* return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } - if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil { + serviceID, err := s.resolveServiceID(ctx, stopReq.Domain) + if err != nil { + return nil, mapExposeError(ctx, err) + } + + if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil { return nil, mapExposeError(ctx, err) } @@ -180,13 +206,46 @@ func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) { s.reverseProxyManager = mgr } -func exposeProtocolToString(p proto.ExposeProtocol) string { +// resolveServiceID looks up the service by its globally unique domain. +func (s *Server) resolveServiceID(ctx context.Context, domain string) (string, error) { + if domain == "" { + return "", status.Errorf(codes.InvalidArgument, "domain is required") + } + + svc, err := s.accountManager.GetStore().GetServiceByDomain(ctx, domain) + if err != nil { + return "", err + } + return svc.ID, nil +} + +func exposeProtocolToString(p proto.ExposeProtocol) (string, error) { switch p { - case proto.ExposeProtocol_EXPOSE_HTTP: - return "http" - case proto.ExposeProtocol_EXPOSE_HTTPS: - return "https" + case proto.ExposeProtocol_EXPOSE_HTTP, proto.ExposeProtocol_EXPOSE_HTTPS: + return "http", nil + case proto.ExposeProtocol_EXPOSE_TCP: + return "tcp", nil + case proto.ExposeProtocol_EXPOSE_UDP: + return "udp", nil + case proto.ExposeProtocol_EXPOSE_TLS: + return "tls", nil default: - return "http" + return "", fmt.Errorf("unsupported expose protocol: %v", p) + } +} + +// exposeTargetProtocol returns the target protocol for the given expose protocol. +// For HTTP mode, this is http or https (the scheme used to connect to the backend). +// For L4 modes, this is tcp or udp (the transport used to connect to the backend). +func exposeTargetProtocol(p proto.ExposeProtocol) string { + switch p { + case proto.ExposeProtocol_EXPOSE_HTTPS: + return rpservice.TargetProtoHTTPS + case proto.ExposeProtocol_EXPOSE_TCP, proto.ExposeProtocol_EXPOSE_TLS: + return rpservice.TargetProtoTCP + case proto.ExposeProtocol_EXPOSE_UDP: + return rpservice.TargetProtoUDP + default: + return rpservice.TargetProtoHTTP } } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index e2d0f1abe..31a0ba0db 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -32,6 +32,7 @@ import ( proxyauth "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/shared/hash/argon2id" "github.com/netbirdio/netbird/shared/management/proto" + nbstatus "github.com/netbirdio/netbird/shared/management/status" ) type ProxyOIDCConfig struct { @@ -45,12 +46,6 @@ type ProxyOIDCConfig struct { KeysLocation string } -// ClusterInfo contains information about a proxy cluster. -type ClusterInfo struct { - Address string - ConnectedProxies int -} - // ProxyServiceServer implements the ProxyService gRPC server type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -61,9 +56,9 @@ type ProxyServiceServer struct { // Manager for access logs accessLogManager accesslogs.Manager + mu sync.RWMutex // Manager for reverse proxy operations serviceManager rpservice.Manager - // ProxyController for service updates and cluster management proxyController proxy.Controller @@ -84,23 +79,26 @@ type ProxyServiceServer struct { // Store for PKCE verifiers pkceVerifierStore *PKCEVerifierStore + + cancel context.CancelFunc } const pkceVerifierTTL = 10 * time.Minute // proxyConnection represents a connected proxy type proxyConnection struct { - proxyID string - address string - stream proto.ProxyService_GetMappingUpdateServer - sendChan chan *proto.GetMappingUpdateResponse - ctx context.Context - cancel context.CancelFunc + proxyID string + address string + capabilities *proto.ProxyCapabilities + stream proto.ProxyService_GetMappingUpdateServer + sendChan chan *proto.GetMappingUpdateResponse + ctx context.Context + cancel context.CancelFunc } // NewProxyServiceServer creates a new proxy service server. func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) s := &ProxyServiceServer{ accessLogManager: accessLogMgr, oidcConfig: oidcConfig, @@ -109,6 +107,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + cancel: cancel, } go s.cleanupStaleProxies(ctx) return s @@ -130,11 +129,22 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { } } +// Close stops background goroutines. +func (s *ProxyServiceServer) Close() { + s.cancel() +} + +// SetServiceManager sets the service manager. Must be called before serving. func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) { + s.mu.Lock() + defer s.mu.Unlock() s.serviceManager = manager } +// SetProxyController sets the proxy controller. Must be called before serving. func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) { + s.mu.Lock() + defer s.mu.Unlock() s.proxyController = proxyController } @@ -157,12 +167,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest connCtx, cancel := context.WithCancel(ctx) conn := &proxyConnection{ - proxyID: proxyID, - address: proxyAddress, - stream: stream, - sendChan: make(chan *proto.GetMappingUpdateResponse, 100), - ctx: connCtx, - cancel: cancel, + proxyID: proxyID, + address: proxyAddress, + capabilities: req.GetCapabilities(), + stream: stream, + sendChan: make(chan *proto.GetMappingUpdateResponse, 100), + ctx: connCtx, + cancel: cancel, } s.connectedProxies.Store(proxyID, conn) @@ -231,29 +242,18 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) { } // sendSnapshot sends the initial snapshot of services to the connecting proxy. -// Only services matching the proxy's cluster address are sent. +// Only entries matching the proxy's cluster address are sent. func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { - services, err := s.serviceManager.GetGlobalServices(ctx) - if err != nil { - return fmt.Errorf("get services from store: %w", err) - } - if !isProxyAddressValid(conn.address) { return fmt.Errorf("proxy address is invalid") } - var filtered []*rpservice.Service - for _, service := range services { - if !service.Enabled { - continue - } - if service.ProxyCluster == "" || service.ProxyCluster != conn.address { - continue - } - filtered = append(filtered, service) + mappings, err := s.snapshotServiceMappings(ctx, conn) + if err != nil { + return err } - if len(filtered) == 0 { + if len(mappings) == 0 { if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ InitialSyncComplete: true, }); err != nil { @@ -262,9 +262,30 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return nil } - for i, service := range filtered { - // Generate one-time authentication token for each service in the snapshot - // Tokens are not persistent on the proxy, so we need to generate new ones on reconnection + for i, m := range mappings { + if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{m}, + InitialSyncComplete: i == len(mappings)-1, + }); err != nil { + return fmt.Errorf("send proxy mapping: %w", err) + } + } + + return nil +} + +func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) { + services, err := s.serviceManager.GetGlobalServices(ctx) + if err != nil { + return nil, fmt.Errorf("get services from store: %w", err) + } + + var mappings []*proto.ProxyMapping + for _, service := range services { + if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address { + continue + } + token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute) if err != nil { log.WithFields(log.Fields{ @@ -274,25 +295,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec continue } - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ - Mapping: []*proto.ProxyMapping{ - service.ToProtoMapping( - rpservice.Create, // Initial snapshot, all records are "new" for the proxy. - token, - s.GetOIDCValidationConfig(), - ), - }, - InitialSyncComplete: i == len(filtered)-1, - }); err != nil { - log.WithFields(log.Fields{ - "domain": service.Domain, - "account": service.AccountID, - }).WithError(err).Error("failed to send proxy mapping") - return fmt.Errorf("send proxy mapping: %w", err) - } + m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) + mappings = append(mappings, m) } - - return nil + return mappings, nil } // isProxyAddressValid validates a proxy address @@ -305,8 +311,8 @@ func isProxyAddressValid(addr string) bool { func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) { for { select { - case msg := <-conn.sendChan: - if err := conn.stream.Send(msg); err != nil { + case resp := <-conn.sendChan: + if err := conn.stream.Send(resp); err != nil { errChan <- err return } @@ -361,12 +367,12 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes log.Debugf("Broadcasting service update to all connected proxy servers") s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) - msg := s.perProxyMessage(update, conn.proxyID) - if msg == nil { + resp := s.perProxyMessage(update, conn.proxyID) + if resp == nil { return true } select { - case conn.sendChan <- msg: + case conn.sendChan <- resp: log.Debugf("Sent service update to proxy server %s", conn.proxyID) default: log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) @@ -495,9 +501,40 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { Auth: m.Auth, PassHostHeader: m.PassHostHeader, RewriteRedirects: m.RewriteRedirects, + Mode: m.Mode, + ListenPort: m.ListenPort, } } +// ClusterSupportsCustomPorts returns whether any connected proxy in the given +// cluster reports custom port support. Returns nil if no proxy has reported +// capabilities (old proxies that predate the field). +func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool { + if s.proxyController == nil { + return nil + } + + var hasCapabilities bool + for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) { + connVal, ok := s.connectedProxies.Load(pid) + if !ok { + continue + } + conn := connVal.(*proxyConnection) + if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil { + continue + } + if *conn.capabilities.SupportsCustomPorts { + return ptr(true) + } + hasCapabilities = true + } + if hasCapabilities { + return ptr(false) + } + return nil +} + func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { @@ -585,7 +622,7 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic return token, nil } -// SendStatusUpdate handles status updates from proxy clients +// SendStatusUpdate handles status updates from proxy clients. func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { accountID := req.GetAccountId() serviceID := req.GetServiceId() @@ -604,6 +641,17 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required") } + internalStatus := protoStatusToInternal(protoStatus) + + if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { + sErr, isNbErr := nbstatus.FromError(err) + if isNbErr && sErr.Type() == nbstatus.NotFound { + return nil, status.Errorf(codes.NotFound, "service %s not found", serviceID) + } + log.WithContext(ctx).WithError(err).Error("failed to update service status") + return nil, status.Errorf(codes.Internal, "update service status: %v", err) + } + if certificateIssued { if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil { log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp") @@ -615,13 +663,6 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se }).Info("Certificate issued timestamp updated") } - internalStatus := protoStatusToInternal(protoStatus) - - if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { - log.WithContext(ctx).WithError(err).Error("failed to update service status") - return nil, status.Errorf(codes.Internal, "update service status: %v", err) - } - log.WithFields(log.Fields{ "service_id": serviceID, "account_id": accountID, @@ -631,7 +672,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se return &proto.SendStatusUpdateResponse{}, nil } -// protoStatusToInternal maps proto status to internal status +// protoStatusToInternal maps proto status to internal service status. func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status { switch protoStatus { case proto.ProxyStatus_PROXY_STATUS_PENDING: @@ -1061,3 +1102,5 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user * return fmt.Errorf("user not in allowed groups") } + +func ptr[T any](v T) *T { return &v } diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index b7abb28b6..1a4ea3330 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -53,6 +53,10 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus return nil } +func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { + return ptr(true) +} + func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string { c.mu.Lock() defer c.mu.Unlock() @@ -70,11 +74,17 @@ func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string // registerFakeProxy adds a fake proxy connection to the server's internal maps // and returns the channel where messages will be received. func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse { + return registerFakeProxyWithCaps(s, proxyID, clusterAddr, nil) +} + +// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities. +func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse { ch := make(chan *proto.GetMappingUpdateResponse, 10) conn := &proxyConnection{ - proxyID: proxyID, - address: clusterAddr, - sendChan: ch, + proxyID: proxyID, + address: clusterAddr, + capabilities: caps, + sendChan: ch, } s.connectedProxies.Store(proxyID, conn) @@ -83,15 +93,29 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan return ch } -func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse { +// drainMapping drains a single ProxyMapping from the channel. +func drainMapping(ch chan *proto.GetMappingUpdateResponse) *proto.ProxyMapping { select { - case msg := <-ch: - return msg + case resp := <-ch: + if len(resp.Mapping) > 0 { + return resp.Mapping[0] + } + return nil case <-time.After(time.Second): return nil } } +// drainEmpty checks if a channel has no message within timeout. +func drainEmpty(ch chan *proto.GetMappingUpdateResponse) bool { + select { + case <-ch: + return false + case <-time.After(100 * time.Millisecond): + return true + } +} + func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { ctx := context.Background() tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) @@ -129,10 +153,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { tokens := make([]string, numProxies) for i, ch := range channels { - resp := drainChannel(ch) - require.NotNil(t, resp, "proxy %d should receive a message", i) - require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i) - msg := resp.Mapping[0] + msg := drainMapping(ch) + require.NotNil(t, msg, "proxy %d should receive a message", i) assert.Equal(t, mapping.Domain, msg.Domain) assert.Equal(t, mapping.Id, msg.Id) assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i) @@ -181,16 +203,14 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { s.SendServiceUpdateToCluster(context.Background(), mapping, cluster) - resp1 := drainChannel(ch1) - resp2 := drainChannel(ch2) - require.NotNil(t, resp1) - require.NotNil(t, resp2) - require.Len(t, resp1.Mapping, 1) - require.Len(t, resp2.Mapping, 1) + msg1 := drainMapping(ch1) + msg2 := drainMapping(ch2) + require.NotNil(t, msg1) + require.NotNil(t, msg2) // Delete operations should not generate tokens - assert.Empty(t, resp1.Mapping[0].AuthToken) - assert.Empty(t, resp2.Mapping[0].AuthToken) + assert.Empty(t, msg1.AuthToken) + assert.Empty(t, msg2.AuthToken) } func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { @@ -224,15 +244,10 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { s.SendServiceUpdate(update) - resp1 := drainChannel(ch1) - resp2 := drainChannel(ch2) - require.NotNil(t, resp1) - require.NotNil(t, resp2) - require.Len(t, resp1.Mapping, 1) - require.Len(t, resp2.Mapping, 1) - - msg1 := resp1.Mapping[0] - msg2 := resp2.Mapping[0] + msg1 := drainMapping(ch1) + msg2 := drainMapping(ch2) + require.NotNil(t, msg1) + require.NotNil(t, msg2) assert.NotEmpty(t, msg1.AuthToken) assert.NotEmpty(t, msg2.AuthToken) @@ -324,3 +339,314 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "invalid state signature") } + +func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + + const cluster = "proxy.example.com" + + // Proxy A supports custom ports. + chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + // Proxy B does NOT support custom ports (shared cloud proxy). + chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + ctx := context.Background() + + // TLS passthrough works on all proxies regardless of custom port support. + tlsMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tls", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tls", + ListenPort: 8443, + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster) + + msgA := drainMapping(chA) + msgB := drainMapping(chB) + assert.NotNil(t, msgA, "proxy-a should receive TLS mapping") + assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)") + + // Send an HTTP mapping: both should receive it. + httpMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-http", + AccountId: "account-1", + Domain: "app.example.com", + Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:80"}}, + } + + s.SendServiceUpdateToCluster(ctx, httpMapping, cluster) + + msgA = drainMapping(chA) + msgB = drainMapping(chB) + assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping") + assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping") +} + +func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + + const cluster = "proxy.example.com" + + chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + tlsMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tls", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tls", + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster) + + msg := drainMapping(chShared) + assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support") +} + +// TestServiceModifyNotifications exercises every possible modification +// scenario for an existing service, verifying the correct update types +// reach the correct clusters. +func TestServiceModifyNotifications(t *testing.T) { + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + newServer := func() (*ProxyServiceServer, map[string]chan *proto.GetMappingUpdateResponse) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + chs := map[string]chan *proto.GetMappingUpdateResponse{ + "cluster-a": registerFakeProxyWithCaps(s, "proxy-a", "cluster-a", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}), + "cluster-b": registerFakeProxyWithCaps(s, "proxy-b", "cluster-b", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}), + } + return s, chs + } + + httpMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping { + return &proto.ProxyMapping{ + Type: updateType, + Id: "svc-1", + AccountId: "acct-1", + Domain: "app.example.com", + Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:8080"}}, + } + } + + tlsOnlyMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping { + return &proto.ProxyMapping{ + Type: updateType, + Id: "svc-1", + AccountId: "acct-1", + Domain: "app.example.com", + Mode: "tls", + ListenPort: 8443, + Path: []*proto.PathMapping{{Target: "10.0.0.1:443"}}, + } + } + + ctx := context.Background() + + t.Run("targets changed sends MODIFIED to same cluster", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg, "cluster-a should receive update") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.NotEmpty(t, msg.AuthToken, "MODIFIED should include token") + assert.True(t, drainEmpty(chs["cluster-b"]), "cluster-b should not receive update") + }) + + t.Run("auth config changed sends MODIFIED", func(t *testing.T) { + s, chs := newServer() + mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.Auth = &proto.Authentication{Password: true, Pin: true} + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.True(t, msg.Auth.Password) + assert.True(t, msg.Auth.Pin) + }) + + t.Run("HTTP to TLS transition sends MODIFIED with TLS config", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.Equal(t, "tls", msg.Mode, "mode should be tls") + assert.Equal(t, int32(8443), msg.ListenPort) + assert.Len(t, msg.Path, 1, "should have one path entry with target address") + assert.Equal(t, "10.0.0.1:443", msg.Path[0].Target) + }) + + t.Run("TLS to HTTP transition sends MODIFIED without TLS", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.Empty(t, msg.Mode, "mode should be empty for HTTP") + assert.True(t, len(msg.Path) > 0) + }) + + t.Run("TLS port changed sends MODIFIED with new port", func(t *testing.T) { + s, chs := newServer() + mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.ListenPort = 9443 + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, int32(9443), msg.ListenPort) + }) + + t.Run("disable sends REMOVED to cluster", func(t *testing.T) { + s, chs := newServer() + // Manager sends Delete when service is disabled + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msg.Type) + assert.Empty(t, msg.AuthToken, "DELETE should not have token") + }) + + t.Run("enable sends CREATED to cluster", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msg.Type) + assert.NotEmpty(t, msg.AuthToken) + }) + + t.Run("domain change with cluster change sends DELETE to old CREATE to new", func(t *testing.T) { + s, chs := newServer() + // This is the pattern the manager produces: + // 1. DELETE on old cluster + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + // 2. CREATE on new cluster + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-b") + + msgA := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgA, "old cluster should receive DELETE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgA.Type) + + msgB := drainMapping(chs["cluster-b"]) + require.NotNil(t, msgB, "new cluster should receive CREATE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgB.Type) + assert.NotEmpty(t, msgB.AuthToken) + }) + + t.Run("domain change same cluster sends DELETE then CREATE", func(t *testing.T) { + s, chs := newServer() + // Domain changes within same cluster: manager sends DELETE (old domain) + CREATE (new domain). + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a") + + msgDel := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgDel, "same cluster should receive DELETE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgDel.Type) + + msgCreate := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgCreate, "same cluster should receive CREATE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgCreate.Type) + assert.NotEmpty(t, msgCreate.AuthToken) + }) + + t.Run("TLS passthrough sent to all proxies", func(t *testing.T) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + const cluster = "proxy.example.com" + chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + // TLS passthrough works on all proxies regardless of custom port support + s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster) + + msgModern := drainMapping(chModern) + require.NotNil(t, msgModern, "modern proxy receives TLS update") + assert.Equal(t, "tls", msgModern.Mode) + + msgLegacy := drainMapping(chLegacy) + assert.NotNil(t, msgLegacy, "legacy proxy should also receive TLS passthrough") + }) + + t.Run("TLS on default port NOT filtered for legacy proxy", func(t *testing.T) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + const cluster = "proxy.example.com" + chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.ListenPort = 0 // default port + s.SendServiceUpdateToCluster(ctx, mapping, cluster) + + msgLegacy := drainMapping(chLegacy) + assert.NotNil(t, msgLegacy, "legacy proxy should receive TLS on default port") + }) + + t.Run("passthrough and rewrite flags propagated", func(t *testing.T) { + s, chs := newServer() + mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.PassHostHeader = true + mapping.RewriteRedirects = true + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.True(t, msg.PassHostHeader) + assert.True(t, msg.RewriteRedirects) + }) + + t.Run("multiple paths propagated in MODIFIED", func(t *testing.T) { + s, chs := newServer() + mapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, + Id: "svc-multi", + AccountId: "acct-1", + Domain: "multi.example.com", + Path: []*proto.PathMapping{ + {Path: "/", Target: "http://10.0.0.1:8080"}, + {Path: "/api", Target: "http://10.0.0.2:9090"}, + {Path: "/ws", Target: "http://10.0.0.3:3000"}, + }, + } + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + require.Len(t, msg.Path, 3, "all paths should be present") + assert.Equal(t, "/", msg.Path[0].Path) + assert.Equal(t, "/api", msg.Path[1].Path) + assert.Equal(t, "/ws", msg.Path[2].Path) + }) +} diff --git a/management/server/account.go b/management/server/account.go index 01d0eebfa..75db36a5f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -335,7 +335,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || oldSettings.DNSDomain != newSettings.DNSDomain || - oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion { + oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion || + oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways { updateAccountPeers = true } @@ -376,6 +377,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleAutoUpdateAlwaysSettings(ctx, oldSettings, newSettings, userID, accountID) am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID) if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { return nil, err @@ -493,6 +495,16 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con } } +func (am *DefaultAccountManager) handleAutoUpdateAlwaysSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { + if oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways { + if newSettings.AutoUpdateAlways { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateAlwaysEnabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateAlwaysDisabled, nil) + } + } +} + func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { oldEnabled := oldSettings.PeerExposeEnabled newEnabled := newSettings.PeerExposeEnabled diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index fa6c45856..ac53a9fa8 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -63,11 +63,20 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID) startTime := time.Now() - ac.getAccountRequestCh <- req - result := <-req.ResultChan - log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime)) - return result.Account, result.Err + select { + case <-ctx.Done(): + return nil, ctx.Err() + case ac.getAccountRequestCh <- req: + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-req.ResultChan: + log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime)) + return result.Account, result.Err + } } func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) { @@ -86,7 +95,14 @@ func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, acco result := &AccountResult{Account: account, Err: err} for _, req := range requests { - req.ResultChan <- result + if account != nil { + // Shallow copy the account so each goroutine gets its own struct value. + // This prevents data races when callers mutate fields like Policies. + accountCopy := *account + req.ResultChan <- &AccountResult{Account: &accountCopy, Err: err} + } else { + req.ResultChan <- result + } close(req.ResultChan) } } diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 948d599ba..ddc3e00c3 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -220,6 +220,11 @@ const ( // AccountPeerExposeDisabled indicates that a user disabled peer expose for the account AccountPeerExposeDisabled Activity = 115 + // AccountAutoUpdateAlwaysEnabled indicates that a user enabled always auto-update for the account + AccountAutoUpdateAlwaysEnabled Activity = 116 + // AccountAutoUpdateAlwaysDisabled indicates that a user disabled always auto-update for the account + AccountAutoUpdateAlwaysDisabled Activity = 117 + // DomainAdded indicates that a user added a custom domain DomainAdded Activity = 118 // DomainDeleted indicates that a user deleted a custom domain @@ -339,6 +344,8 @@ var activityMap = map[Activity]Code{ UserCreated: {"User created", "user.create"}, AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"}, + AccountAutoUpdateAlwaysEnabled: {"Account auto-update always enabled", "account.setting.auto.update.always.enable"}, + AccountAutoUpdateAlwaysDisabled: {"Account auto-update always disabled", "account.setting.auto.update.always.disable"}, IdentityProviderCreated: {"Identity provider created", "identityprovider.create"}, IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"}, diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ddeda6d7f..ad36b9d46 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -174,9 +174,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks instance.AddEndpoints(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router) if serviceManager != nil && reverseProxyDomainManager != nil { - reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) + reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) } - // Register OAuth callback handler for proxy authentication if proxyGRPCServer != nil { oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies) diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 27a57c434..cc5567e3d 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -225,6 +225,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS return nil, fmt.Errorf("invalid AutoUpdateVersion") } } + if req.Settings.AutoUpdateAlways != nil { + returnSettings.AutoUpdateAlways = *req.Settings.AutoUpdateAlways + } return returnSettings, nil } @@ -348,6 +351,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A LazyConnectionEnabled: &settings.LazyConnectionEnabled, DnsDomain: &settings.DNSDomain, AutoUpdateVersion: &settings.AutoUpdateVersion, + AutoUpdateAlways: &settings.AutoUpdateAlways, EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled, LocalAuthDisabled: &settings.LocalAuthDisabled, } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 6cbd5908d..739dfe2f6 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -121,6 +121,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -146,6 +147,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -171,6 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr("latest"), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -196,6 +199,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -221,6 +225,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -246,6 +251,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), diff --git a/management/server/http/middleware/bypass/bypass.go b/management/server/http/middleware/bypass/bypass.go index 9447704cb..ddece7152 100644 --- a/management/server/http/middleware/bypass/bypass.go +++ b/management/server/http/middleware/bypass/bypass.go @@ -51,19 +51,28 @@ func GetList() []string { // This can be used to bypass authz/authn middlewares for certain paths, such as webhooks that implement their own authentication. func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *http.Request) bool { byPassMutex.RLock() - defer byPassMutex.RUnlock() - + var matched bool for bypassPath := range bypassPaths { - matched, err := path.Match(bypassPath, requestPath) + m, err := path.Match(bypassPath, requestPath) if err != nil { - log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err) + list := make([]string, 0, len(bypassPaths)) + for k := range bypassPaths { + list = append(list, k) + } + log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %v: %v", bypassPath, requestPath, list, err) continue } - if matched { - h.ServeHTTP(w, r) - return true + if m { + matched = true + break } } + byPassMutex.RUnlock() + + if matched { + h.ServeHTTP(w, r) + return true + } return false } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 462013963..6bd269a2c 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -12,6 +12,7 @@ import ( "go.opentelemetry.io/otel/metric/noop" "github.com/netbirdio/management-integrations/integrations" + accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" @@ -113,6 +114,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) } + domainManager.SetClusterCapabilities(serviceProxyController) serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager) proxyServiceServer.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager) diff --git a/management/server/job/channel.go b/management/server/job/channel.go index c4dc98a68..c4454c4c9 100644 --- a/management/server/job/channel.go +++ b/management/server/job/channel.go @@ -28,7 +28,13 @@ func NewChannel() *Channel { return jc } -func (jc *Channel) AddEvent(ctx context.Context, responseWait time.Duration, event *Event) error { +func (jc *Channel) AddEvent(ctx context.Context, responseWait time.Duration, event *Event) (err error) { + defer func() { + if r := recover(); r != nil { + err = ErrJobChannelClosed + } + }() + select { case <-ctx.Done(): return ctx.Err() diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index bfefce388..8732cf89f 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -219,7 +219,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { servicesStatusActive int servicesStatusPending int servicesStatusError int - servicesTargetType map[string]int + servicesTargetType map[rpservice.TargetType]int servicesAuthPassword int servicesAuthPin int servicesAuthOIDC int @@ -232,7 +232,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { rulesDirection = make(map[string]int) activeUsersLastDay = make(map[string]struct{}) embeddedIdpTypes = make(map[string]int) - servicesTargetType = make(map[string]int) + servicesTargetType = make(map[rpservice.TargetType]int) uptime = time.Since(w.startupTime).Seconds() connections := w.connManager.GetAllConnectedPeers() version = nbversion.NetbirdVersion() @@ -434,7 +434,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["custom_domains_validated"] = customDomainsValidated for targetType, count := range servicesTargetType { - metricsProperties["services_target_type_"+targetType] = count + metricsProperties["services_target_type_"+string(targetType)] = count } for idpType, count := range embeddedIdpTypes { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 5997c10e2..b3fbfe141 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -30,6 +30,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" @@ -4996,6 +4997,7 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpse return service, nil } + func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) { tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { @@ -5041,16 +5043,16 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS } // RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service. -func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error { +func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error { result := s.db.Model(&rpservice.Service{}). - Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral). + Where("id = ? AND account_id = ? AND source_peer = ? AND source = ?", serviceID, accountID, peerID, rpservice.SourceEphemeral). Update("meta_last_renewed_at", time.Now()) if result.Error != nil { log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error) return status.Errorf(status.Internal, "renew ephemeral service") } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "no active expose session for domain %s", domain) + return status.Errorf(status.NotFound, "no active expose session for service %s", serviceID) } return nil } @@ -5133,6 +5135,37 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock return id != "", nil } +// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port. +func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var services []*rpservice.Service + result := tx.Where("proxy_cluster = ? AND mode = ? AND listen_port = ?", proxyCluster, mode, listenPort).Find(&services) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "query services by cluster and port") + } + + return services, nil +} + +// GetServicesByCluster returns all services for the given proxy cluster. +func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var services []*rpservice.Service + result := tx.Where("proxy_cluster = ?", proxyCluster).Find(&services) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "query services by cluster") + } + return services, nil +} + func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) { tx := s.db diff --git a/management/server/store/store.go b/management/server/store/store.go index 1fa99fd05..8bb52f38a 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -261,10 +261,12 @@ type Store interface { GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) - RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error + RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) + GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) + GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) ListFreeDomains(ctx context.Context, accountID string) ([]string, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 130df4485..e75e35b94 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1991,6 +1991,36 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength) } +// GetServicesByCluster mocks base method. +func (m *MockStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServicesByCluster", ctx, lockStrength, proxyCluster) + ret0, _ := ret[0].([]*service.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServicesByCluster indicates an expected call of GetServicesByCluster. +func (mr *MockStoreMockRecorder) GetServicesByCluster(ctx, lockStrength, proxyCluster interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByCluster", reflect.TypeOf((*MockStore)(nil).GetServicesByCluster), ctx, lockStrength, proxyCluster) +} + +// GetServicesByClusterAndPort mocks base method. +func (m *MockStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster, mode string, listenPort uint16) ([]*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServicesByClusterAndPort", ctx, lockStrength, proxyCluster, mode, listenPort) + ret0, _ := ret[0].([]*service.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServicesByClusterAndPort indicates an expected call of GetServicesByClusterAndPort. +func (mr *MockStoreMockRecorder) GetServicesByClusterAndPort(ctx, lockStrength, proxyCluster, mode, listenPort interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByClusterAndPort", reflect.TypeOf((*MockStore)(nil).GetServicesByClusterAndPort), ctx, lockStrength, proxyCluster, mode, listenPort) +} + // GetSetupKeyByID mocks base method. func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) { m.ctrl.T.Helper() @@ -2447,17 +2477,17 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID } // RenewEphemeralService mocks base method. -func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error { +func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain) + ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, serviceID) ret0, _ := ret[0].(error) return ret0 } // RenewEphemeralService indicates an expected call of RenewEphemeralService. -func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, serviceID) } // RevokeProxyAccessToken mocks base method. diff --git a/management/server/types/account.go b/management/server/types/account.go index 6145ceeb2..269fc7a88 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -907,8 +907,8 @@ func (a *Account) Copy() *Account { } services := []*service.Service{} - for _, service := range a.Services { - services = append(services, service.Copy()) + for _, svc := range a.Services { + services = append(services, svc.Copy()) } return &Account{ @@ -1605,12 +1605,12 @@ func (a *Account) GetPoliciesForNetworkResource(resourceId string) []*Policy { networkResourceGroups := a.getNetworkResourceGroups(resourceId) for _, policy := range a.Policies { - if !policy.Enabled { + if policy == nil || !policy.Enabled { continue } for _, rule := range policy.Rules { - if !rule.Enabled { + if rule == nil || !rule.Enabled { continue } @@ -1812,15 +1812,18 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) { } a.injectServiceProxyPolicies(ctx, service, proxyPeersByCluster) } + } func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) { + proxyPeers := proxyPeersByCluster[service.ProxyCluster] for _, target := range service.Targets { if !target.Enabled { continue } - a.injectTargetProxyPolicies(ctx, service, target, proxyPeersByCluster[service.ProxyCluster]) + a.injectTargetProxyPolicies(ctx, service, target, proxyPeers) } + } func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) { @@ -1840,13 +1843,13 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *servic } } -func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) { +func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (uint16, bool) { if target.Port != 0 { return target.Port, true } switch target.Protocol { - case "https": + case "https", "tls": return 443, true case "http": return 80, true @@ -1856,17 +1859,23 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) } } -func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy { - policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path) +func (a *Account) createProxyPolicy(svc *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port uint16, path string) *Policy { + policyID := fmt.Sprintf("proxy-access-%s-%s-%s", svc.ID, proxyPeer.ID, path) + + protocol := PolicyRuleProtocolTCP + if svc.Mode == service.ModeUDP { + protocol = PolicyRuleProtocolUDP + } + return &Policy{ ID: policyID, - Name: fmt.Sprintf("Proxy Access to %s", service.Name), + Name: fmt.Sprintf("Proxy Access to %s", svc.Name), Enabled: true, Rules: []*PolicyRule{ { ID: policyID, PolicyID: policyID, - Name: fmt.Sprintf("Allow access to %s", service.Name), + Name: fmt.Sprintf("Allow access to %s", svc.Name), Enabled: true, SourceResource: Resource{ ID: proxyPeer.ID, @@ -1877,12 +1886,12 @@ func (a *Account) createProxyPolicy(service *service.Service, target *service.Ta Type: ResourceType(target.TargetType), }, Bidirectional: false, - Protocol: PolicyRuleProtocolTCP, + Protocol: protocol, Action: PolicyTrafficActionAccept, PortRanges: []RulePortRange{ { - Start: uint16(port), - End: uint16(port), + Start: port, + End: port, }, }, }, diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go index 1eb25cecc..bd4244546 100644 --- a/management/server/types/account_components.go +++ b/management/server/types/account_components.go @@ -368,7 +368,7 @@ func (a *Account) getPeersGroupsPoliciesRoutes( func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) ([]string, bool) { peerInGroups := false - filteredPeerIDs := make([]string, 0, len(a.Peers)) + filteredPeerIDs := make([]string, 0, len(groups)) seenPeerIds := make(map[string]struct{}, len(groups)) for _, gid := range groups { @@ -378,7 +378,7 @@ func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerI } if group.IsGroupAll() || len(groups) == 1 { - filteredPeerIDs = filteredPeerIDs[:0] + filteredPeerIDs = make([]string, 0, len(group.Peers)) peerInGroups = false for _, pid := range group.Peers { peer, ok := a.Peers[pid] diff --git a/management/server/types/network.go b/management/server/types/network.go index d3708d80a..0d13de10f 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -152,6 +152,8 @@ func (n *Network) CurrentSerial() uint64 { } func (n *Network) Copy() *Network { + n.Mu.Lock() + defer n.Mu.Unlock() return &Network{ Identifier: n.Identifier, Net: n.Net, diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go index ab6b006e6..23d84a994 100644 --- a/management/server/types/networkmap_components.go +++ b/management/server/types/networkmap_components.go @@ -134,7 +134,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { sourcePeers, ) - dnsManagementStatus := c.getPeerDNSManagementStatus(targetPeerID) + dnsManagementStatus := c.getPeerDNSManagementStatusFromGroups(peerGroups) dnsUpdate := nbdns.Config{ ServiceEnable: dnsManagementStatus, } @@ -152,7 +152,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { customZones = append(customZones, c.AccountZones...) dnsUpdate.CustomZones = customZones - dnsUpdate.NameServerGroups = c.getPeerNSGroups(targetPeerID) + dnsUpdate.NameServerGroups = c.getPeerNSGroupsFromGroups(targetPeerID, peerGroups) } return &NetworkMap{ @@ -278,6 +278,16 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( peers := make([]*nbpeer.Peer, 0) return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + protocol := rule.Protocol + if protocol == PolicyRuleProtocolNetbirdSSH { + protocol = PolicyRuleProtocolTCP + } + + protocolStr := string(protocol) + actionStr := string(rule.Action) + dirStr := strconv.Itoa(direction) + portsJoined := strings.Join(rule.Ports, ",") + for _, peer := range groupPeers { if peer == nil { continue @@ -288,21 +298,18 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( peersExists[peer.ID] = struct{}{} } - protocol := rule.Protocol - if protocol == PolicyRuleProtocolNetbirdSSH { - protocol = PolicyRuleProtocolTCP - } + peerIP := net.IP(peer.IP).String() fr := FirewallRule{ PolicyID: rule.ID, - PeerIP: net.IP(peer.IP).String(), + PeerIP: peerIP, Direction: direction, - Action: string(rule.Action), - Protocol: string(protocol), + Action: actionStr, + Protocol: protocolStr, } - ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + - fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") + ruleID := rule.ID + peerIP + dirStr + + protocolStr + actionStr + portsJoined if _, ok := rulesExists[ruleID]; ok { continue } @@ -313,13 +320,7 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( continue } - rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{ - ID: rule.ID, - Ports: rule.Ports, - PortRanges: rule.PortRanges, - Protocol: rule.Protocol, - Action: rule.Action, - }, targetPeer)...) + rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) } }, func() ([]*nbpeer.Peer, []*FirewallRule) { return peers, rules @@ -395,7 +396,7 @@ func (c *NetworkMapComponents) getPeerFromResource(resource Resource, peerID str } func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) { - var peersToConnect []*nbpeer.Peer + peersToConnect := make([]*nbpeer.Peer, 0, len(aclPeers)) var expiredPeers []*nbpeer.Peer for _, p := range aclPeers { @@ -410,35 +411,35 @@ func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.P return peersToConnect, expiredPeers } -func (c *NetworkMapComponents) getPeerDNSManagementStatus(peerID string) bool { - peerGroups := c.GetPeerGroups(peerID) - enabled := true +func (c *NetworkMapComponents) getPeerDNSManagementStatusFromGroups(peerGroups map[string]struct{}) bool { for _, groupID := range c.DNSSettings.DisabledManagementGroups { if _, found := peerGroups[groupID]; found { - enabled = false - break + return false } } - return enabled + return true } -func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup { - groupList := c.GetPeerGroups(peerID) - +func (c *NetworkMapComponents) getPeerNSGroupsFromGroups(peerID string, groupList map[string]struct{}) []*nbdns.NameServerGroup { var peerNSGroups []*nbdns.NameServerGroup + targetPeerInfo := c.GetPeerInfo(peerID) + if targetPeerInfo == nil { + return peerNSGroups + } + + peerIPStr := targetPeerInfo.IP.String() + for _, nsGroup := range c.NameServerGroups { if !nsGroup.Enabled { continue } for _, gID := range nsGroup.Groups { - _, found := groupList[gID] - if found { - targetPeerInfo := c.GetPeerInfo(peerID) - if targetPeerInfo != nil && !c.peerIsNameserver(targetPeerInfo, nsGroup) { + if _, found := groupList[gID]; found { + if !c.peerIsNameserver(peerIPStr, nsGroup) { peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break } + break } } } @@ -446,9 +447,9 @@ func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServe return peerNSGroups } -func (c *NetworkMapComponents) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { +func (c *NetworkMapComponents) peerIsNameserver(peerIPStr string, nsGroup *nbdns.NameServerGroup) bool { for _, ns := range nsGroup.NameServers { - if peerInfo.IP.String() == ns.IP.String() { + if peerIPStr == ns.IP.String() { return true } } @@ -489,14 +490,13 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute } seenRoute[r.ID] = struct{}{} - routeObj := c.copyRoute(r) - routeObj.Peer = peerInfo.Key + r.Peer = peerInfo.Key if r.Enabled { - enabledRoutes = append(enabledRoutes, routeObj) + enabledRoutes = append(enabledRoutes, r) return } - disabledRoutes = append(disabledRoutes, routeObj) + disabledRoutes = append(disabledRoutes, r) } for _, r := range c.Routes { @@ -510,7 +510,7 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute continue } - newPeerRoute := c.copyRoute(r) + newPeerRoute := r.Copy() newPeerRoute.Peer = id newPeerRoute.PeerGroups = nil newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) @@ -519,50 +519,13 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute } } if r.Peer == peerID { - takeRoute(c.copyRoute(r)) + takeRoute(r.Copy()) } } return enabledRoutes, disabledRoutes } -func (c *NetworkMapComponents) copyRoute(r *route.Route) *route.Route { - var groups, accessControlGroups, peerGroups []string - var domains domain.List - - if r.Groups != nil { - groups = append([]string{}, r.Groups...) - } - if r.AccessControlGroups != nil { - accessControlGroups = append([]string{}, r.AccessControlGroups...) - } - if r.PeerGroups != nil { - peerGroups = append([]string{}, r.PeerGroups...) - } - if r.Domains != nil { - domains = append(domain.List{}, r.Domains...) - } - - return &route.Route{ - ID: r.ID, - AccountID: r.AccountID, - Network: r.Network, - NetworkType: r.NetworkType, - Description: r.Description, - Peer: r.Peer, - PeerID: r.PeerID, - Metric: r.Metric, - Masquerade: r.Masquerade, - NetID: r.NetID, - Enabled: r.Enabled, - Groups: groups, - AccessControlGroups: accessControlGroups, - PeerGroups: peerGroups, - Domains: domains, - KeepRoute: r.KeepRoute, - SkipAutoApply: r.SkipAutoApply, - } -} func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { var filteredRoutes []*route.Route diff --git a/management/server/types/settings.go b/management/server/types/settings.go index e165968fc..4ea79ec72 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -61,6 +61,10 @@ type Settings struct { // AutoUpdateVersion client auto-update version AutoUpdateVersion string `gorm:"default:'disabled'"` + // AutoUpdateAlways when true, updates are installed automatically in the background; + // when false, updates require user interaction from the UI + AutoUpdateAlways bool `gorm:"default:false"` + // EmbeddedIdpEnabled indicates if the embedded identity provider is enabled. // This is a runtime-only field, not stored in the database. EmbeddedIdpEnabled bool `gorm:"-"` @@ -91,6 +95,7 @@ func (s *Settings) Copy() *Settings { DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, AutoUpdateVersion: s.AutoUpdateVersion, + AutoUpdateAlways: s.AutoUpdateAlways, EmbeddedIdpEnabled: s.EmbeddedIdpEnabled, LocalAuthDisabled: s.LocalAuthDisabled, } diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 50aa38b29..d82f5b7fc 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -7,6 +7,7 @@ import ( "os/signal" "strconv" "syscall" + "time" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -34,28 +35,32 @@ var ( ) var ( - debugLogs bool - mgmtAddr string - addr string - proxyDomain string - certDir string - acmeCerts bool - acmeAddr string - acmeDir string - acmeEABKID string - acmeEABHMACKey string - acmeChallengeType string - debugEndpoint bool - debugEndpointAddr string - healthAddr string - forwardedProto string - trustedProxies string - certFile string - certKeyFile string - certLockMethod string - wgPort int - proxyProtocol bool - preSharedKey string + logLevel string + debugLogs bool + mgmtAddr string + addr string + proxyDomain string + defaultDialTimeout time.Duration + certDir string + acmeCerts bool + acmeAddr string + acmeDir string + acmeEABKID string + acmeEABHMACKey string + acmeChallengeType string + debugEndpoint bool + debugEndpointAddr string + healthAddr string + forwardedProto string + trustedProxies string + certFile string + certKeyFile string + certLockMethod string + wildcardCertDir string + wgPort uint16 + proxyProtocol bool + preSharedKey string + supportsCustomPorts bool ) var rootCmd = &cobra.Command{ @@ -68,7 +73,9 @@ var rootCmd = &cobra.Command{ } func init() { + rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", envStringOrDefault("NB_PROXY_LOG_LEVEL", "info"), "Log level: panic, fatal, error, warn, info, debug, trace") rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs") + _ = rootCmd.PersistentFlags().MarkDeprecated("debug", "use --log-level instead") rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to") rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on") rootCmd.Flags().StringVar(&proxyDomain, "domain", envStringOrDefault("NB_PROXY_DOMAIN", ""), "The Domain at which this proxy will be reached. e.g., netbird.example.com") @@ -87,9 +94,12 @@ func init() { rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory") rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory") rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease") - rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") + rootCmd.Flags().StringVar(&wildcardCertDir, "wildcard-cert-dir", envStringOrDefault("NB_PROXY_WILDCARD_CERT_DIR", ""), "Directory containing wildcard certificate pairs (.crt/.key). Wildcard patterns are extracted from SANs automatically") + rootCmd.Flags().Uint16Var(&wgPort, "wg-port", envUint16OrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers") + rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough") + rootCmd.Flags().DurationVar(&defaultDialTimeout, "default-dial-timeout", envDurationOrDefault("NB_PROXY_DEFAULT_DIAL_TIMEOUT", 0), "Default backend dial timeout when no per-service timeout is set (e.g. 30s)") } // Execute runs the root command. @@ -115,7 +125,7 @@ func runServer(cmd *cobra.Command, args []string) error { return fmt.Errorf("proxy token is required: set %s environment variable", envProxyToken) } - level := "error" + level := logLevel if debugLogs { level = "debug" } @@ -162,9 +172,12 @@ func runServer(cmd *cobra.Command, args []string) error { ForwardedProto: forwardedProto, TrustedProxies: parsedTrustedProxies, CertLockMethod: nbacme.CertLockMethod(certLockMethod), + WildcardCertDir: wildcardCertDir, WireguardPort: wgPort, ProxyProtocol: proxyProtocol, PreSharedKey: preSharedKey, + SupportsCustomPorts: supportsCustomPorts, + DefaultDialTimeout: defaultDialTimeout, } ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) @@ -197,12 +210,24 @@ func envStringOrDefault(key string, def string) string { return v } -func envIntOrDefault(key string, def int) int { +func envUint16OrDefault(key string, def uint16) uint16 { v, exists := os.LookupEnv(key) if !exists { return def } - parsed, err := strconv.Atoi(v) + parsed, err := strconv.ParseUint(v, 10, 16) + if err != nil { + return def + } + return uint16(parsed) +} + +func envDurationOrDefault(key string, def time.Duration) time.Duration { + v, exists := os.LookupEnv(key) + if !exists { + return def + } + parsed, err := time.ParseDuration(v) if err != nil { return def } diff --git a/proxy/handle_mapping_stream_test.go b/proxy/handle_mapping_stream_test.go index d2ad3f67e..cb16c0814 100644 --- a/proxy/handle_mapping_stream_test.go +++ b/proxy/handle_mapping_stream_test.go @@ -38,11 +38,18 @@ func (m *mockMappingStream) Context() context.Context { return context.Backgroun func (m *mockMappingStream) SendMsg(any) error { return nil } func (m *mockMappingStream) RecvMsg(any) error { return nil } +func closedChan() chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +} + func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) { checker := health.NewChecker(nil, nil) s := &Server{ Logger: log.StandardLogger(), healthChecker: checker, + routerReady: closedChan(), } stream := &mockMappingStream{ @@ -62,6 +69,7 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) { s := &Server{ Logger: log.StandardLogger(), healthChecker: checker, + routerReady: closedChan(), } stream := &mockMappingStream{ @@ -78,7 +86,8 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) { func TestHandleMappingStream_NilHealthChecker(t *testing.T) { s := &Server{ - Logger: log.StandardLogger(), + Logger: log.StandardLogger(), + routerReady: closedChan(), } stream := &mockMappingStream{ diff --git a/proxy/internal/accesslog/logger.go b/proxy/internal/accesslog/logger.go index 4ba5a7755..5b05ab195 100644 --- a/proxy/internal/accesslog/logger.go +++ b/proxy/internal/accesslog/logger.go @@ -6,11 +6,13 @@ import ( "sync" "time" + "github.com/rs/xid" log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -19,6 +21,7 @@ const ( bytesThreshold = 1024 * 1024 * 1024 // Log every 1GB usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours + logSendTimeout = 10 * time.Second ) type domainUsage struct { @@ -79,22 +82,63 @@ func (l *Logger) Close() { type logEntry struct { ID string - AccountID string - ServiceId string + AccountID types.AccountID + ServiceId types.ServiceID Host string Path string DurationMs int64 Method string ResponseCode int32 - SourceIp string + SourceIP netip.Addr AuthMechanism string UserId string AuthSuccess bool BytesUpload int64 BytesDownload int64 + Protocol Protocol } -func (l *Logger) log(ctx context.Context, entry logEntry) { +// Protocol identifies the transport protocol of an access log entry. +type Protocol string + +const ( + ProtocolHTTP Protocol = "http" + ProtocolTCP Protocol = "tcp" + ProtocolUDP Protocol = "udp" + ProtocolTLS Protocol = "tls" +) + +// L4Entry holds the data for a layer-4 (TCP/UDP) access log entry. +type L4Entry struct { + AccountID types.AccountID + ServiceID types.ServiceID + Protocol Protocol + Host string // SNI hostname or listen address + SourceIP netip.Addr + DurationMs int64 + BytesUpload int64 + BytesDownload int64 +} + +// LogL4 sends an access log entry for a layer-4 connection (TCP or UDP). +// The call is non-blocking: the gRPC send happens in a background goroutine. +func (l *Logger) LogL4(entry L4Entry) { + le := logEntry{ + ID: xid.New().String(), + AccountID: entry.AccountID, + ServiceId: entry.ServiceID, + Protocol: entry.Protocol, + Host: entry.Host, + SourceIP: entry.SourceIP, + DurationMs: entry.DurationMs, + BytesUpload: entry.BytesUpload, + BytesDownload: entry.BytesDownload, + } + l.log(le) + l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload) +} + +func (l *Logger) log(entry logEntry) { // Fire off the log request in a separate routine. // This increases the possibility of losing a log message // (although it should still get logged in the event of an error), @@ -105,31 +149,37 @@ func (l *Logger) log(ctx context.Context, entry logEntry) { // allow for resolving that on the server. now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary. go func() { - logCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout) defer cancel() if entry.AuthMechanism != auth.MethodOIDC.String() { entry.UserId = "" } + + var sourceIP string + if entry.SourceIP.IsValid() { + sourceIP = entry.SourceIP.String() + } + if _, err := l.client.SendAccessLog(logCtx, &proto.SendAccessLogRequest{ Log: &proto.AccessLog{ LogId: entry.ID, - AccountId: entry.AccountID, + AccountId: string(entry.AccountID), Timestamp: now, - ServiceId: entry.ServiceId, + ServiceId: string(entry.ServiceId), Host: entry.Host, Path: entry.Path, DurationMs: entry.DurationMs, Method: entry.Method, ResponseCode: entry.ResponseCode, - SourceIp: entry.SourceIp, + SourceIp: sourceIP, AuthMechanism: entry.AuthMechanism, UserId: entry.UserId, AuthSuccess: entry.AuthSuccess, BytesUpload: entry.BytesUpload, BytesDownload: entry.BytesDownload, + Protocol: string(entry.Protocol), }, }); err != nil { - // If it fails to send on the gRPC connection, then at least log it to the error log. l.logger.WithFields(log.Fields{ "service_id": entry.ServiceId, "host": entry.Host, @@ -137,7 +187,7 @@ func (l *Logger) log(ctx context.Context, entry logEntry) { "duration": entry.DurationMs, "method": entry.Method, "response_code": entry.ResponseCode, - "source_ip": entry.SourceIp, + "source_ip": sourceIP, "auth_mechanism": entry.AuthMechanism, "user_id": entry.UserId, "auth_success": entry.AuthSuccess, diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index 7368185c0..593a77ef2 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -67,23 +67,24 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { entry := logEntry{ ID: requestID, ServiceId: capturedData.GetServiceId(), - AccountID: string(capturedData.GetAccountId()), + AccountID: capturedData.GetAccountId(), Host: host, Path: r.URL.Path, DurationMs: duration.Milliseconds(), Method: r.Method, ResponseCode: int32(sw.status), - SourceIp: sourceIp, + SourceIP: sourceIp, AuthMechanism: capturedData.GetAuthMethod(), UserId: capturedData.GetUserID(), AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden, BytesUpload: bytesUpload, BytesDownload: bytesDownload, + Protocol: ProtocolHTTP, } l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId()) - l.log(r.Context(), entry) + l.log(entry) // Track usage for cost monitoring (upload + download) by domain l.trackUsage(host, bytesUpload+bytesDownload) diff --git a/proxy/internal/accesslog/requestip.go b/proxy/internal/accesslog/requestip.go index f111c1322..30c483fd9 100644 --- a/proxy/internal/accesslog/requestip.go +++ b/proxy/internal/accesslog/requestip.go @@ -11,6 +11,6 @@ import ( // proxy configuration. When trustedProxies is non-empty and the direct // connection is from a trusted source, it walks X-Forwarded-For right-to-left // skipping trusted IPs. Otherwise it returns RemoteAddr directly. -func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) string { +func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) netip.Addr { return proxy.ResolveClientIP(r.RemoteAddr, r.Header.Get("X-Forwarded-For"), trustedProxies) } diff --git a/proxy/internal/acme/manager.go b/proxy/internal/acme/manager.go index ebc15314b..a4a220ed7 100644 --- a/proxy/internal/acme/manager.go +++ b/proxy/internal/acme/manager.go @@ -7,9 +7,14 @@ import ( "encoding/asn1" "encoding/base64" "encoding/binary" + "encoding/pem" "fmt" + "math/rand/v2" "net" + "os" + "path/filepath" "slices" + "strings" "sync" "time" @@ -17,6 +22,8 @@ import ( "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" + "github.com/netbirdio/netbird/proxy/internal/certwatch" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -24,7 +31,7 @@ import ( var oidSCTList = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 2} type certificateNotifier interface { - NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error + NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error } type domainState int @@ -36,8 +43,8 @@ const ( ) type domainInfo struct { - accountID string - serviceID string + accountID types.AccountID + serviceID types.ServiceID state domainState err string } @@ -46,6 +53,34 @@ type metricsRecorder interface { RecordCertificateIssuance(duration time.Duration) } +// wildcardEntry maps a domain suffix (e.g. ".example.com") to a certwatch +// watcher that hot-reloads the corresponding wildcard certificate from disk. +type wildcardEntry struct { + suffix string // e.g. ".example.com" + pattern string // e.g. "*.example.com" + watcher *certwatch.Watcher +} + +// ManagerConfig holds the configuration values for the ACME certificate manager. +type ManagerConfig struct { + // CertDir is the directory used for caching ACME certificates. + CertDir string + // ACMEURL is the ACME directory URL (e.g. Let's Encrypt). + ACMEURL string + // EABKID and EABHMACKey are optional External Account Binding credentials + // required by some CAs (e.g. ZeroSSL). EABHMACKey is the base64 + // URL-encoded string provided by the CA. + EABKID string + EABHMACKey string + // LockMethod controls the cross-replica coordination strategy. + LockMethod CertLockMethod + // WildcardDir is an optional path to a directory containing wildcard + // certificate pairs (.crt / .key). Wildcard patterns are + // extracted from the certificates' SAN lists. Domains matching a + // wildcard are served from disk; all others go through ACME. + WildcardDir string +} + // Manager wraps autocert.Manager with domain tracking and cross-replica // coordination via a pluggable locking strategy. The locker prevents // duplicate ACME requests when multiple replicas share a certificate cache. @@ -57,54 +92,182 @@ type Manager struct { mu sync.RWMutex domains map[domain.Domain]*domainInfo + // wildcards holds all loaded wildcard certificates, keyed by suffix. + wildcards []wildcardEntry + certNotifier certificateNotifier logger *log.Logger metrics metricsRecorder } -// NewManager creates a new ACME certificate manager. The certDir is used -// for caching certificates. The lockMethod controls cross-replica -// coordination strategy (see CertLockMethod constants). -// eabKID and eabHMACKey are optional External Account Binding credentials -// required for some CAs like ZeroSSL. The eabHMACKey should be the base64 -// URL-encoded string provided by the CA. -func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod, metrics metricsRecorder) *Manager { +// NewManager creates a new ACME certificate manager. +func NewManager(cfg ManagerConfig, notifier certificateNotifier, logger *log.Logger, metrics metricsRecorder) (*Manager, error) { if logger == nil { logger = log.StandardLogger() } mgr := &Manager{ - certDir: certDir, - locker: newCertLocker(lockMethod, certDir, logger), + certDir: cfg.CertDir, + locker: newCertLocker(cfg.LockMethod, cfg.CertDir, logger), domains: make(map[domain.Domain]*domainInfo), certNotifier: notifier, logger: logger, metrics: metrics, } + if cfg.WildcardDir != "" { + entries, err := loadWildcardDir(cfg.WildcardDir, logger) + if err != nil { + return nil, fmt.Errorf("load wildcard certificates from %q: %w", cfg.WildcardDir, err) + } + mgr.wildcards = entries + } + var eab *acme.ExternalAccountBinding - if eabKID != "" && eabHMACKey != "" { - decodedKey, err := base64.RawURLEncoding.DecodeString(eabHMACKey) + if cfg.EABKID != "" && cfg.EABHMACKey != "" { + decodedKey, err := base64.RawURLEncoding.DecodeString(cfg.EABHMACKey) if err != nil { logger.Errorf("failed to decode EAB HMAC key: %v", err) } else { eab = &acme.ExternalAccountBinding{ - KID: eabKID, + KID: cfg.EABKID, Key: decodedKey, } - logger.Infof("configured External Account Binding with KID: %s", eabKID) + logger.Infof("configured External Account Binding with KID: %s", cfg.EABKID) } } mgr.Manager = &autocert.Manager{ Prompt: autocert.AcceptTOS, HostPolicy: mgr.hostPolicy, - Cache: autocert.DirCache(certDir), + Cache: autocert.DirCache(cfg.CertDir), ExternalAccountBinding: eab, Client: &acme.Client{ - DirectoryURL: acmeURL, + DirectoryURL: cfg.ACMEURL, }, } - return mgr + return mgr, nil +} + +// WatchWildcards starts watching all wildcard certificate files for changes. +// It blocks until ctx is cancelled. It is a no-op if no wildcards are loaded. +func (mgr *Manager) WatchWildcards(ctx context.Context) { + if len(mgr.wildcards) == 0 { + return + } + seen := make(map[*certwatch.Watcher]struct{}) + var wg sync.WaitGroup + for i := range mgr.wildcards { + w := mgr.wildcards[i].watcher + if _, ok := seen[w]; ok { + continue + } + seen[w] = struct{}{} + wg.Add(1) + go func() { + defer wg.Done() + w.Watch(ctx) + }() + } + wg.Wait() +} + +// loadWildcardDir scans dir for .crt files, pairs each with a matching .key +// file, loads them, and extracts wildcard SANs (*.example.com) to build +// the suffix lookup entries. +func loadWildcardDir(dir string, logger *log.Logger) ([]wildcardEntry, error) { + crtFiles, err := filepath.Glob(filepath.Join(dir, "*.crt")) + if err != nil { + return nil, fmt.Errorf("glob certificate files: %w", err) + } + + if len(crtFiles) == 0 { + return nil, fmt.Errorf("no .crt files found in %s", dir) + } + + var entries []wildcardEntry + + for _, crtPath := range crtFiles { + base := strings.TrimSuffix(filepath.Base(crtPath), ".crt") + keyPath := filepath.Join(dir, base+".key") + if _, err := os.Stat(keyPath); err != nil { + logger.Warnf("skipping %s: no matching key file %s", crtPath, keyPath) + continue + } + + watcher, err := certwatch.NewWatcher(crtPath, keyPath, logger) + if err != nil { + logger.Warnf("skipping %s: %v", crtPath, err) + continue + } + + leaf := watcher.Leaf() + if leaf == nil { + logger.Warnf("skipping %s: no parsed leaf certificate", crtPath) + continue + } + + for _, san := range leaf.DNSNames { + suffix, ok := parseWildcard(san) + if !ok { + continue + } + entries = append(entries, wildcardEntry{ + suffix: suffix, + pattern: san, + watcher: watcher, + }) + logger.Infof("wildcard certificate loaded: %s (from %s)", san, filepath.Base(crtPath)) + } + } + + if len(entries) == 0 { + return nil, fmt.Errorf("no wildcard SANs (*.example.com) found in certificates in %s", dir) + } + + return entries, nil +} + +// parseWildcard validates a wildcard domain pattern like "*.example.com" +// and returns the suffix ".example.com" for matching. +func parseWildcard(pattern string) (suffix string, ok bool) { + if !strings.HasPrefix(pattern, "*.") { + return "", false + } + parent := pattern[1:] // ".example.com" + if strings.Count(parent, ".") < 1 { + return "", false + } + return strings.ToLower(parent), true +} + +// findWildcardEntry returns the wildcard entry that covers host, or nil. +func (mgr *Manager) findWildcardEntry(host string) *wildcardEntry { + if len(mgr.wildcards) == 0 { + return nil + } + host = strings.ToLower(host) + for i := range mgr.wildcards { + e := &mgr.wildcards[i] + if !strings.HasSuffix(host, e.suffix) { + continue + } + // Single-level match: prefix before suffix must have no dots. + prefix := strings.TrimSuffix(host, e.suffix) + if len(prefix) > 0 && !strings.Contains(prefix, ".") { + return e + } + } + return nil +} + +// WildcardPatterns returns the wildcard patterns that are currently loaded. +func (mgr *Manager) WildcardPatterns() []string { + patterns := make([]string, len(mgr.wildcards)) + for i, e := range mgr.wildcards { + patterns[i] = e.pattern + } + slices.Sort(patterns) + return patterns } func (mgr *Manager) hostPolicy(_ context.Context, host string) error { @@ -120,8 +283,39 @@ func (mgr *Manager) hostPolicy(_ context.Context, host string) error { return nil } -// AddDomain registers a domain for ACME certificate prefetching. -func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) { +// GetCertificate returns the TLS certificate for the given ClientHello. +// If the requested domain matches a loaded wildcard, the static wildcard +// certificate is returned. Otherwise, the ACME autocert manager handles +// the request. +func (mgr *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if e := mgr.findWildcardEntry(hello.ServerName); e != nil { + return e.watcher.GetCertificate(hello) + } + return mgr.Manager.GetCertificate(hello) +} + +// AddDomain registers a domain for certificate management. Domains that +// match a loaded wildcard are marked ready immediately (they use the +// static wildcard certificate) and the method returns true. All other +// domains go through ACME prefetch and the method returns false. +// +// When AddDomain returns true the caller is responsible for sending any +// certificate-ready notifications after the surrounding operation (e.g. +// mapping update) has committed successfully. +func (mgr *Manager) AddDomain(d domain.Domain, accountID types.AccountID, serviceID types.ServiceID) (wildcardHit bool) { + name := d.PunycodeString() + if e := mgr.findWildcardEntry(name); e != nil { + mgr.mu.Lock() + mgr.domains[d] = &domainInfo{ + accountID: accountID, + serviceID: serviceID, + state: domainReady, + } + mgr.mu.Unlock() + mgr.logger.Debugf("domain %q matches wildcard %q, using static certificate", name, e.pattern) + return true + } + mgr.mu.Lock() mgr.domains[d] = &domainInfo{ accountID: accountID, @@ -131,13 +325,19 @@ func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) { mgr.mu.Unlock() go mgr.prefetchCertificate(d) + return false } // prefetchCertificate proactively triggers certificate generation for a domain. // It acquires a distributed lock to prevent multiple replicas from issuing // duplicate ACME requests. The second replica will block until the first // finishes, then find the certificate in the cache. +// ACME and periodic disk reads race; whichever produces a valid certificate +// first wins. This handles cases where locking is unreliable and another +// replica already wrote the cert to the shared cache. func (mgr *Manager) prefetchCertificate(d domain.Domain) { + time.Sleep(time.Duration(rand.IntN(200)) * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -153,26 +353,105 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) { defer unlock() } - hello := &tls.ClientHelloInfo{ - ServerName: name, - Conn: &dummyConn{ctx: ctx}, - } - - start := time.Now() - cert, err := mgr.GetCertificate(hello) - elapsed := time.Since(start) - if err != nil { - mgr.logger.Warnf("prefetch certificate for domain %q in %s: %v", name, elapsed.String(), err) - mgr.setDomainState(d, domainFailed, err.Error()) + if cert, err := mgr.readCertFromDisk(ctx, name); err == nil { + mgr.logger.Infof("certificate for domain %q already on disk, skipping ACME", name) + mgr.recordAndNotify(ctx, d, name, cert, 0) return } - if mgr.metrics != nil { + // Run ACME in a goroutine so we can race it against periodic disk reads. + // autocert uses its own internal context and cannot be cancelled externally. + type acmeResult struct { + cert *tls.Certificate + err error + } + acmeCh := make(chan acmeResult, 1) + hello := &tls.ClientHelloInfo{ServerName: name, Conn: &dummyConn{ctx: ctx}} + go func() { + cert, err := mgr.GetCertificate(hello) + acmeCh <- acmeResult{cert, err} + }() + + start := time.Now() + diskTicker := time.NewTicker(5 * time.Second) + defer diskTicker.Stop() + + for { + select { + case res := <-acmeCh: + elapsed := time.Since(start) + if res.err != nil { + mgr.logger.Warnf("prefetch certificate for domain %q in %s: %v", name, elapsed.String(), res.err) + mgr.setDomainState(d, domainFailed, res.err.Error()) + return + } + mgr.recordAndNotify(ctx, d, name, res.cert, elapsed) + return + + case <-diskTicker.C: + cert, err := mgr.readCertFromDisk(context.Background(), name) + if err != nil { + continue + } + mgr.logger.Infof("certificate for domain %q appeared on disk after %s", name, time.Since(start).Round(time.Millisecond)) + // Drain the ACME goroutine before marking ready — autocert holds + // an internal write lock on certState while ACME is in flight. + go func() { + select { + case <-acmeCh: + default: + } + mgr.recordAndNotify(context.Background(), d, name, cert, 0) + }() + return + + case <-ctx.Done(): + mgr.logger.Warnf("prefetch certificate for domain %q timed out", name) + mgr.setDomainState(d, domainFailed, ctx.Err().Error()) + return + } + } +} + +// readCertFromDisk reads and parses a certificate directly from the autocert +// DirCache, bypassing autocert's internal certState mutex. Safe to call +// concurrently with an in-flight ACME request for the same domain. +func (mgr *Manager) readCertFromDisk(ctx context.Context, name string) (*tls.Certificate, error) { + if mgr.Cache == nil { + return nil, fmt.Errorf("no cache configured") + } + data, err := mgr.Cache.Get(ctx, name) + if err != nil { + return nil, err + } + privBlock, certsPEM := pem.Decode(data) + if privBlock == nil || !strings.Contains(privBlock.Type, "PRIVATE") { + return nil, fmt.Errorf("no private key in cache for %q", name) + } + cert, err := tls.X509KeyPair(certsPEM, pem.EncodeToMemory(privBlock)) + if err != nil { + return nil, fmt.Errorf("parse cached certificate for %q: %w", name, err) + } + if len(cert.Certificate) > 0 { + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil, fmt.Errorf("parse leaf for %q: %w", name, err) + } + if time.Now().After(leaf.NotAfter) { + return nil, fmt.Errorf("cached certificate for %q expired at %s", name, leaf.NotAfter) + } + cert.Leaf = leaf + } + return &cert, nil +} + +// recordAndNotify records metrics, marks the domain ready, logs cert details, +// and notifies the cert notifier. +func (mgr *Manager) recordAndNotify(ctx context.Context, d domain.Domain, name string, cert *tls.Certificate, elapsed time.Duration) { + if elapsed > 0 && mgr.metrics != nil { mgr.metrics.RecordCertificateIssuance(elapsed) } - mgr.setDomainState(d, domainReady, "") - now := time.Now() if cert != nil && cert.Leaf != nil { leaf := cert.Leaf @@ -188,11 +467,9 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) { } else { mgr.logger.Infof("certificate for domain %q ready in %s", name, elapsed.Round(time.Millisecond)) } - mgr.mu.RLock() info := mgr.domains[d] mgr.mu.RUnlock() - if info != nil && mgr.certNotifier != nil { if err := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.serviceID, name); err != nil { mgr.logger.Warnf("notify certificate ready for domain %q: %v", name, err) diff --git a/proxy/internal/acme/manager_test.go b/proxy/internal/acme/manager_test.go index 30a27c612..ceb9ca13a 100644 --- a/proxy/internal/acme/manager_test.go +++ b/proxy/internal/acme/manager_test.go @@ -2,16 +2,29 @@ package acme import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" ) func TestHostPolicy(t *testing.T) { - mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil) - mgr.AddDomain("example.com", "acc1", "rp1") + mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil) + require.NoError(t, err) + mgr.AddDomain("example.com", types.AccountID("acc1"), types.ServiceID("rp1")) // Wait for the background prefetch goroutine to finish so the temp dir // can be cleaned up without a race. @@ -70,7 +83,8 @@ func TestHostPolicy(t *testing.T) { } func TestDomainStates(t *testing.T) { - mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil) + mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil) + require.NoError(t, err) assert.Equal(t, 0, mgr.PendingCerts(), "initially zero") assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains") @@ -80,8 +94,8 @@ func TestDomainStates(t *testing.T) { // AddDomain starts as pending, then the prefetch goroutine will fail // (no real ACME server) and transition to failed. - mgr.AddDomain("a.example.com", "acc1", "rp1") - mgr.AddDomain("b.example.com", "acc1", "rp1") + mgr.AddDomain("a.example.com", types.AccountID("acc1"), types.ServiceID("rp1")) + mgr.AddDomain("b.example.com", types.AccountID("acc1"), types.ServiceID("rp1")) assert.Equal(t, 2, mgr.TotalDomains(), "two domains registered") @@ -100,3 +114,193 @@ func TestDomainStates(t *testing.T) { assert.Contains(t, failed, "b.example.com") assert.Empty(t, mgr.ReadyDomains()) } + +func TestParseWildcard(t *testing.T) { + tests := []struct { + pattern string + wantSuffix string + wantOK bool + }{ + {"*.example.com", ".example.com", true}, + {"*.foo.example.com", ".foo.example.com", true}, + {"*.COM", ".com", true}, // single-label TLD + {"example.com", "", false}, // no wildcard prefix + {"*example.com", "", false}, // missing dot + {"**.example.com", "", false}, // double star + {"", "", false}, + } + + for _, tc := range tests { + t.Run(tc.pattern, func(t *testing.T) { + suffix, ok := parseWildcard(tc.pattern) + assert.Equal(t, tc.wantOK, ok) + if ok { + assert.Equal(t, tc.wantSuffix, suffix) + } + }) + } +} + +func TestMatchesWildcard(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + tests := []struct { + host string + match bool + }{ + {"foo.example.com", true}, + {"bar.example.com", true}, + {"FOO.Example.COM", true}, // case insensitive + {"example.com", false}, // bare parent + {"sub.foo.example.com", false}, // multi-level + {"notexample.com", false}, + {"", false}, + } + + for _, tc := range tests { + t.Run(tc.host, func(t *testing.T) { + assert.Equal(t, tc.match, mgr.findWildcardEntry(tc.host) != nil) + }) + } +} + +// generateSelfSignedCert creates a temporary self-signed certificate and key +// for testing purposes. The baseName controls the output filenames: +// .crt and .key. +func generateSelfSignedCert(t *testing.T, dir, baseName string, dnsNames ...string) { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: dnsNames[0]}, + DNSNames: dnsNames, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + certFile, err := os.Create(filepath.Join(dir, baseName+".crt")) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + require.NoError(t, certFile.Close()) + + keyDER, err := x509.MarshalECPrivateKey(key) + require.NoError(t, err) + keyFile, err := os.Create(filepath.Join(dir, baseName+".key")) + require.NoError(t, err) + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + require.NoError(t, keyFile.Close()) +} + +func TestWildcardAddDomainSkipsACME(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + // Add a wildcard-matching domain — should be immediately ready. + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) + assert.Equal(t, 0, mgr.PendingCerts(), "wildcard domain should not be pending") + assert.Equal(t, []string{"foo.example.com"}, mgr.ReadyDomains()) + + // Add a non-wildcard domain — should go through ACME (pending then failed). + mgr.AddDomain("other.net", types.AccountID("acc2"), types.ServiceID("svc2")) + assert.Equal(t, 2, mgr.TotalDomains()) + + // Wait for the ACME prefetch to fail. + assert.Eventually(t, func() bool { + return mgr.PendingCerts() == 0 + }, 30*time.Second, 100*time.Millisecond) + + assert.Equal(t, []string{"foo.example.com"}, mgr.ReadyDomains()) + assert.Contains(t, mgr.FailedDomains(), "other.net") +} + +func TestWildcardGetCertificate(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) + + // GetCertificate for a wildcard-matching domain should return the static cert. + cert, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) + require.NoError(t, err) + require.NotNil(t, cert) + assert.Contains(t, cert.Leaf.DNSNames, "*.example.com") +} + +func TestMultipleWildcards(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + generateSelfSignedCert(t, wcDir, "other", "*.other.org") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + assert.ElementsMatch(t, []string{"*.example.com", "*.other.org"}, mgr.WildcardPatterns()) + + // Both wildcards should resolve. + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) + mgr.AddDomain("bar.other.org", types.AccountID("acc2"), types.ServiceID("svc2")) + + assert.Equal(t, 0, mgr.PendingCerts()) + assert.ElementsMatch(t, []string{"foo.example.com", "bar.other.org"}, mgr.ReadyDomains()) + + // GetCertificate routes to the correct cert. + cert1, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) + require.NoError(t, err) + assert.Contains(t, cert1.Leaf.DNSNames, "*.example.com") + + cert2, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "bar.other.org"}) + require.NoError(t, err) + assert.Contains(t, cert2.Leaf.DNSNames, "*.other.org") + + // Non-matching domain falls through to ACME. + mgr.AddDomain("custom.net", types.AccountID("acc3"), types.ServiceID("svc3")) + assert.Eventually(t, func() bool { + return mgr.PendingCerts() == 0 + }, 30*time.Second, 100*time.Millisecond) + assert.Contains(t, mgr.FailedDomains(), "custom.net") +} + +func TestWildcardDirEmpty(t *testing.T) { + wcDir := t.TempDir() + // Empty directory — no .crt files. + _, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no .crt files found") +} + +func TestWildcardDirNonWildcardCert(t *testing.T) { + wcDir := t.TempDir() + // Certificate without a wildcard SAN. + generateSelfSignedCert(t, wcDir, "plain", "plain.example.com") + + _, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no wildcard SANs") +} + +func TestNoWildcardDir(t *testing.T) { + // Empty string means no wildcard dir — pure ACME mode. + mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil) + require.NoError(t, err) + assert.Empty(t, mgr.WildcardPatterns()) +} diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 8a966faa3..3cf86e4b3 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -44,8 +44,8 @@ type DomainConfig struct { Schemes []Scheme SessionPublicKey ed25519.PublicKey SessionExpiration time.Duration - AccountID string - ServiceID string + AccountID types.AccountID + ServiceID types.ServiceID } type validationResult struct { @@ -124,7 +124,7 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) { func setCapturedIDs(r *http.Request, config DomainConfig) { if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetAccountId(types.AccountID(config.AccountID)) + cd.SetAccountId(config.AccountID) cd.SetServiceId(config.ServiceID) } } @@ -275,7 +275,7 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool { // session JWTs. Returns an error if the key is missing or invalid. // Callers must not serve the domain if this returns an error, to avoid // exposing an unauthenticated service. -func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error { +func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error { if len(schemes) == 0 { mw.domainsMux.Lock() defer mw.domainsMux.Unlock() diff --git a/proxy/internal/auth/oidc.go b/proxy/internal/auth/oidc.go index bf178d432..a60e6437a 100644 --- a/proxy/internal/auth/oidc.go +++ b/proxy/internal/auth/oidc.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -17,14 +18,14 @@ type urlGenerator interface { } type OIDC struct { - id string - accountId string + id types.ServiceID + accountId types.AccountID forwardedProto string client urlGenerator } // NewOIDC creates a new OIDC authentication scheme -func NewOIDC(client urlGenerator, id, accountId, forwardedProto string) OIDC { +func NewOIDC(client urlGenerator, id types.ServiceID, accountId types.AccountID, forwardedProto string) OIDC { return OIDC{ id: id, accountId: accountId, @@ -53,8 +54,8 @@ func (o OIDC) Authenticate(r *http.Request) (string, string, error) { } res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{ - Id: o.id, - AccountId: o.accountId, + Id: string(o.id), + AccountId: string(o.accountId), RedirectUrl: redirectURL.String(), }) if err != nil { diff --git a/proxy/internal/auth/password.go b/proxy/internal/auth/password.go index 208423465..6a7eda3e1 100644 --- a/proxy/internal/auth/password.go +++ b/proxy/internal/auth/password.go @@ -5,17 +5,19 @@ import ( "net/http" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) const passwordFormId = "password" type Password struct { - id, accountId string - client authenticator + id types.ServiceID + accountId types.AccountID + client authenticator } -func NewPassword(client authenticator, id, accountId string) Password { +func NewPassword(client authenticator, id types.ServiceID, accountId types.AccountID) Password { return Password{ id: id, accountId: accountId, @@ -41,8 +43,8 @@ func (p Password) Authenticate(r *http.Request) (string, string, error) { } res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{ - Id: p.id, - AccountId: p.accountId, + Id: string(p.id), + AccountId: string(p.accountId), Request: &proto.AuthenticateRequest_Password{ Password: &proto.PasswordRequest{ Password: password, diff --git a/proxy/internal/auth/pin.go b/proxy/internal/auth/pin.go index c1eb56071..4d08f3dc6 100644 --- a/proxy/internal/auth/pin.go +++ b/proxy/internal/auth/pin.go @@ -5,17 +5,19 @@ import ( "net/http" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) const pinFormId = "pin" type Pin struct { - id, accountId string - client authenticator + id types.ServiceID + accountId types.AccountID + client authenticator } -func NewPin(client authenticator, id, accountId string) Pin { +func NewPin(client authenticator, id types.ServiceID, accountId types.AccountID) Pin { return Pin{ id: id, accountId: accountId, @@ -41,8 +43,8 @@ func (p Pin) Authenticate(r *http.Request) (string, string, error) { } res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{ - Id: p.id, - AccountId: p.accountId, + Id: string(p.id), + AccountId: string(p.accountId), Request: &proto.AuthenticateRequest_Pin{ Pin: &proto.PinRequest{ Pin: pin, diff --git a/proxy/internal/certwatch/watcher.go b/proxy/internal/certwatch/watcher.go index 78ad1ab7c..6366a53c6 100644 --- a/proxy/internal/certwatch/watcher.go +++ b/proxy/internal/certwatch/watcher.go @@ -67,6 +67,13 @@ func (w *Watcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, erro return w.cert, nil } +// Leaf returns the parsed leaf certificate, or nil if not yet loaded. +func (w *Watcher) Leaf() *x509.Certificate { + w.mu.RLock() + defer w.mu.RUnlock() + return w.leaf +} + // Watch starts watching for certificate file changes. It blocks until // ctx is cancelled. It uses fsnotify for immediate detection and falls // back to polling if fsnotify is unavailable (e.g. on NFS). diff --git a/proxy/internal/conntrack/conn.go b/proxy/internal/conntrack/conn.go index 97055d992..8446d638f 100644 --- a/proxy/internal/conntrack/conn.go +++ b/proxy/internal/conntrack/conn.go @@ -10,10 +10,11 @@ import ( type trackedConn struct { net.Conn tracker *HijackTracker + host string } func (c *trackedConn) Close() error { - c.tracker.conns.Delete(c) + c.tracker.remove(c) return c.Conn.Close() } @@ -22,6 +23,7 @@ func (c *trackedConn) Close() error { type trackingWriter struct { http.ResponseWriter tracker *HijackTracker + host string } func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { @@ -33,8 +35,8 @@ func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if err != nil { return nil, nil, err } - tc := &trackedConn{Conn: conn, tracker: w.tracker} - w.tracker.conns.Store(tc, struct{}{}) + tc := &trackedConn{Conn: conn, tracker: w.tracker, host: w.host} + w.tracker.add(tc) return tc, buf, nil } diff --git a/proxy/internal/conntrack/hijacked.go b/proxy/internal/conntrack/hijacked.go index d76cebc08..911f93f3d 100644 --- a/proxy/internal/conntrack/hijacked.go +++ b/proxy/internal/conntrack/hijacked.go @@ -1,7 +1,6 @@ package conntrack import ( - "net" "net/http" "sync" ) @@ -10,10 +9,14 @@ import ( // upgrades). http.Server.Shutdown does not close hijacked connections, so // they must be tracked and closed explicitly during graceful shutdown. // +// Connections are indexed by the request Host so they can be closed +// per-domain when a service mapping is removed. +// // Use Middleware as the outermost HTTP middleware to ensure hijacked // connections are tracked and automatically deregistered when closed. type HijackTracker struct { - conns sync.Map // net.Conn → struct{} + mu sync.Mutex + conns map[*trackedConn]struct{} } // Middleware returns an HTTP middleware that wraps the ResponseWriter so that @@ -21,21 +24,73 @@ type HijackTracker struct { // tracker when closed. This should be the outermost middleware in the chain. func (t *HijackTracker) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r) + next.ServeHTTP(&trackingWriter{ + ResponseWriter: w, + tracker: t, + host: hostOnly(r.Host), + }, r) }) } -// CloseAll closes all tracked hijacked connections and returns the number -// of connections that were closed. +// CloseAll closes all tracked hijacked connections and returns the count. func (t *HijackTracker) CloseAll() int { - var count int - t.conns.Range(func(key, _ any) bool { - if conn, ok := key.(net.Conn); ok { - _ = conn.Close() - count++ - } - t.conns.Delete(key) - return true - }) - return count + t.mu.Lock() + conns := t.conns + t.conns = nil + t.mu.Unlock() + + for tc := range conns { + _ = tc.Conn.Close() + } + return len(conns) +} + +// CloseByHost closes all tracked hijacked connections for the given host +// and returns the number of connections closed. +func (t *HijackTracker) CloseByHost(host string) int { + host = hostOnly(host) + t.mu.Lock() + var toClose []*trackedConn + for tc := range t.conns { + if tc.host == host { + toClose = append(toClose, tc) + } + } + for _, tc := range toClose { + delete(t.conns, tc) + } + t.mu.Unlock() + + for _, tc := range toClose { + _ = tc.Conn.Close() + } + return len(toClose) +} + +func (t *HijackTracker) add(tc *trackedConn) { + t.mu.Lock() + if t.conns == nil { + t.conns = make(map[*trackedConn]struct{}) + } + t.conns[tc] = struct{}{} + t.mu.Unlock() +} + +func (t *HijackTracker) remove(tc *trackedConn) { + t.mu.Lock() + delete(t.conns, tc) + t.mu.Unlock() +} + +// hostOnly strips the port from a host:port string. +func hostOnly(hostport string) string { + for i := len(hostport) - 1; i >= 0; i-- { + if hostport[i] == ':' { + return hostport[:i] + } + if hostport[i] < '0' || hostport[i] > '9' { + return hostport + } + } + return hostport } diff --git a/proxy/internal/conntrack/hijacked_test.go b/proxy/internal/conntrack/hijacked_test.go new file mode 100644 index 000000000..9ceefff78 --- /dev/null +++ b/proxy/internal/conntrack/hijacked_test.go @@ -0,0 +1,142 @@ +package conntrack + +import ( + "bufio" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeHijackWriter implements http.ResponseWriter and http.Hijacker for testing. +type fakeHijackWriter struct { + http.ResponseWriter + conn net.Conn +} + +func (f *fakeHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) + return f.conn, rw, nil +} + +func TestCloseByHost(t *testing.T) { + var tracker HijackTracker + + // Simulate hijacking two connections for different hosts. + connA1, connA2 := net.Pipe() + defer connA2.Close() + connB1, connB2 := net.Pipe() + defer connB2.Close() + + twA := &trackingWriter{ + ResponseWriter: httptest.NewRecorder(), + tracker: &tracker, + host: "a.example.com", + } + twB := &trackingWriter{ + ResponseWriter: httptest.NewRecorder(), + tracker: &tracker, + host: "b.example.com", + } + + // Use fakeHijackWriter to provide the Hijack method. + twA.ResponseWriter = &fakeHijackWriter{ResponseWriter: twA.ResponseWriter, conn: connA1} + twB.ResponseWriter = &fakeHijackWriter{ResponseWriter: twB.ResponseWriter, conn: connB1} + + _, _, err := twA.Hijack() + require.NoError(t, err) + _, _, err = twB.Hijack() + require.NoError(t, err) + + tracker.mu.Lock() + assert.Equal(t, 2, len(tracker.conns), "should track 2 connections") + tracker.mu.Unlock() + + // Close only host A. + n := tracker.CloseByHost("a.example.com") + assert.Equal(t, 1, n, "should close 1 connection for host A") + + tracker.mu.Lock() + assert.Equal(t, 1, len(tracker.conns), "should have 1 remaining connection") + tracker.mu.Unlock() + + // Verify host A's conn is actually closed. + buf := make([]byte, 1) + _, err = connA2.Read(buf) + assert.Error(t, err, "host A pipe should be closed") + + // Host B should still be alive. + go func() { _, _ = connB1.Write([]byte("x")) }() + + // Close all remaining. + n = tracker.CloseAll() + assert.Equal(t, 1, n, "should close remaining 1 connection") + + tracker.mu.Lock() + assert.Equal(t, 0, len(tracker.conns), "should have 0 connections after CloseAll") + tracker.mu.Unlock() +} + +func TestCloseAll(t *testing.T) { + var tracker HijackTracker + + for range 5 { + c1, c2 := net.Pipe() + defer c2.Close() + tc := &trackedConn{Conn: c1, tracker: &tracker, host: "test.com"} + tracker.add(tc) + } + + tracker.mu.Lock() + assert.Equal(t, 5, len(tracker.conns)) + tracker.mu.Unlock() + + n := tracker.CloseAll() + assert.Equal(t, 5, n) + + // Double CloseAll is safe. + n = tracker.CloseAll() + assert.Equal(t, 0, n) +} + +func TestTrackedConn_AutoDeregister(t *testing.T) { + var tracker HijackTracker + + c1, c2 := net.Pipe() + defer c2.Close() + + tc := &trackedConn{Conn: c1, tracker: &tracker, host: "auto.com"} + tracker.add(tc) + + tracker.mu.Lock() + assert.Equal(t, 1, len(tracker.conns)) + tracker.mu.Unlock() + + // Close the tracked conn: should auto-deregister. + require.NoError(t, tc.Close()) + + tracker.mu.Lock() + assert.Equal(t, 0, len(tracker.conns), "should auto-deregister on close") + tracker.mu.Unlock() +} + +func TestHostOnly(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"example.com:443", "example.com"}, + {"example.com", "example.com"}, + {"127.0.0.1:8080", "127.0.0.1"}, + {"[::1]:443", "[::1]"}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + assert.Equal(t, tt.want, hostOnly(tt.input)) + }) + } +} diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 885c574bc..01b0bc8e6 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -152,7 +152,7 @@ func (c *Client) printClients(data map[string]any) { return } - _, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT") + _, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "SERVICES", "HAS CLIENT") _, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110)) for _, item := range clients { @@ -166,7 +166,7 @@ func (c *Client) printClientRow(item any) { return } - domains := c.extractDomains(client) + services := c.extractServiceKeys(client) hasClient := "no" if hc, ok := client["has_client"].(bool); ok && hc { hasClient = "yes" @@ -175,20 +175,20 @@ func (c *Client) printClientRow(item any) { _, _ = fmt.Fprintf(c.out, "%-38s %-12v %s %s\n", client["account_id"], client["age"], - domains, + services, hasClient, ) } -func (c *Client) extractDomains(client map[string]any) string { - d, ok := client["domains"].([]any) +func (c *Client) extractServiceKeys(client map[string]any) string { + d, ok := client["service_keys"].([]any) if !ok || len(d) == 0 { return "-" } parts := make([]string, len(d)) - for i, domain := range d { - parts[i] = fmt.Sprint(domain) + for i, key := range d { + parts[i] = fmt.Sprint(key) } return strings.Join(parts, ", ") } diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index ab75c8b72..237010922 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -189,7 +189,7 @@ type indexData struct { Version string Uptime string ClientCount int - TotalDomains int + TotalServices int CertsTotal int CertsReady int CertsPending int @@ -202,7 +202,7 @@ type indexData struct { type clientData struct { AccountID string - Domains string + Services string Age string Status string } @@ -211,9 +211,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b clients := h.provider.ListClientsForDebug() sortedIDs := sortedAccountIDs(clients) - totalDomains := 0 + totalServices := 0 for _, info := range clients { - totalDomains += info.DomainCount + totalServices += info.ServiceCount } var certsTotal, certsReady, certsPending, certsFailed int @@ -234,24 +234,24 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b for _, id := range sortedIDs { info := clients[id] clientsJSON = append(clientsJSON, map[string]interface{}{ - "account_id": info.AccountID, - "domain_count": info.DomainCount, - "domains": info.Domains, - "has_client": info.HasClient, - "created_at": info.CreatedAt, - "age": time.Since(info.CreatedAt).Round(time.Second).String(), + "account_id": info.AccountID, + "service_count": info.ServiceCount, + "service_keys": info.ServiceKeys, + "has_client": info.HasClient, + "created_at": info.CreatedAt, + "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } resp := map[string]interface{}{ - "version": version.NetbirdVersion(), - "uptime": time.Since(h.startTime).Round(time.Second).String(), - "client_count": len(clients), - "total_domains": totalDomains, - "certs_total": certsTotal, - "certs_ready": certsReady, - "certs_pending": certsPending, - "certs_failed": certsFailed, - "clients": clientsJSON, + "version": version.NetbirdVersion(), + "uptime": time.Since(h.startTime).Round(time.Second).String(), + "client_count": len(clients), + "total_services": totalServices, + "certs_total": certsTotal, + "certs_ready": certsReady, + "certs_pending": certsPending, + "certs_failed": certsFailed, + "clients": clientsJSON, } if len(certsPendingDomains) > 0 { resp["certs_pending_domains"] = certsPendingDomains @@ -278,7 +278,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b Version: version.NetbirdVersion(), Uptime: time.Since(h.startTime).Round(time.Second).String(), ClientCount: len(clients), - TotalDomains: totalDomains, + TotalServices: totalServices, CertsTotal: certsTotal, CertsReady: certsReady, CertsPending: certsPending, @@ -291,9 +291,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b for _, id := range sortedIDs { info := clients[id] - domains := info.Domains.SafeString() - if domains == "" { - domains = "-" + services := strings.Join(info.ServiceKeys, ", ") + if services == "" { + services = "-" } status := "No client" if info.HasClient { @@ -301,7 +301,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b } data.Clients = append(data.Clients, clientData{ AccountID: string(info.AccountID), - Domains: domains, + Services: services, Age: time.Since(info.CreatedAt).Round(time.Second).String(), Status: status, }) @@ -324,12 +324,12 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want for _, id := range sortedIDs { info := clients[id] clientsJSON = append(clientsJSON, map[string]interface{}{ - "account_id": info.AccountID, - "domain_count": info.DomainCount, - "domains": info.Domains, - "has_client": info.HasClient, - "created_at": info.CreatedAt, - "age": time.Since(info.CreatedAt).Round(time.Second).String(), + "account_id": info.AccountID, + "service_count": info.ServiceCount, + "service_keys": info.ServiceKeys, + "has_client": info.HasClient, + "created_at": info.CreatedAt, + "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } h.writeJSON(w, map[string]interface{}{ @@ -347,9 +347,9 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want for _, id := range sortedIDs { info := clients[id] - domains := info.Domains.SafeString() - if domains == "" { - domains = "-" + services := strings.Join(info.ServiceKeys, ", ") + if services == "" { + services = "-" } status := "No client" if info.HasClient { @@ -357,7 +357,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want } data.Clients = append(data.Clients, clientData{ AccountID: string(info.AccountID), - Domains: domains, + Services: services, Age: time.Since(info.CreatedAt).Round(time.Second).String(), Status: status, }) diff --git a/proxy/internal/debug/templates/clients.html b/proxy/internal/debug/templates/clients.html index 4d455b2bb..bfc25f95a 100644 --- a/proxy/internal/debug/templates/clients.html +++ b/proxy/internal/debug/templates/clients.html @@ -12,14 +12,14 @@ - + {{range .Clients}} - + diff --git a/proxy/internal/debug/templates/index.html b/proxy/internal/debug/templates/index.html index 16ab3d979..5bd25adfc 100644 --- a/proxy/internal/debug/templates/index.html +++ b/proxy/internal/debug/templates/index.html @@ -27,19 +27,19 @@
    {{range .CertsFailedDomains}}
  • {{.Domain}}: {{.Error}}
  • {{end}}
{{end}} -

Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})

+

Clients ({{.ClientCount}}) | Services ({{.TotalServices}})

{{if .Clients}}
Account IDDomainsServices Age Status
{{.AccountID}}{{.Domains}}{{.Services}} {{.Age}} {{.Status}}
- + {{range .Clients}} - + diff --git a/proxy/internal/metrics/l4_metrics_test.go b/proxy/internal/metrics/l4_metrics_test.go new file mode 100644 index 000000000..055158828 --- /dev/null +++ b/proxy/internal/metrics/l4_metrics_test.go @@ -0,0 +1,69 @@ +package metrics_test + +import ( + "context" + "reflect" + "testing" + "time" + + promexporter "go.opentelemetry.io/otel/exporters/prometheus" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + + "github.com/netbirdio/netbird/proxy/internal/metrics" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func newTestMetrics(t *testing.T) *metrics.Metrics { + t.Helper() + + exporter, err := promexporter.New() + if err != nil { + t.Fatalf("create prometheus exporter: %v", err) + } + + provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(exporter)) + pkg := reflect.TypeOf(metrics.Metrics{}).PkgPath() + meter := provider.Meter(pkg) + + m, err := metrics.New(context.Background(), meter) + if err != nil { + t.Fatalf("create metrics: %v", err) + } + return m +} + +func TestL4ServiceGauge(t *testing.T) { + m := newTestMetrics(t) + + m.L4ServiceAdded(types.ServiceModeTCP) + m.L4ServiceAdded(types.ServiceModeTCP) + m.L4ServiceAdded(types.ServiceModeUDP) + m.L4ServiceRemoved(types.ServiceModeTCP) +} + +func TestTCPRelayMetrics(t *testing.T) { + m := newTestMetrics(t) + + acct := types.AccountID("acct-1") + + m.TCPRelayStarted(acct) + m.TCPRelayStarted(acct) + m.TCPRelayEnded(acct, 10*time.Second, 1000, 500) + m.TCPRelayDialError(acct) + m.TCPRelayRejected(acct) +} + +func TestUDPSessionMetrics(t *testing.T) { + m := newTestMetrics(t) + + acct := types.AccountID("acct-2") + + m.UDPSessionStarted(acct) + m.UDPSessionStarted(acct) + m.UDPSessionEnded(acct) + m.UDPSessionDialError(acct) + m.UDPSessionRejected(acct) + m.UDPPacketRelayed(types.RelayDirectionClientToBackend, 100) + m.UDPPacketRelayed(types.RelayDirectionClientToBackend, 200) + m.UDPPacketRelayed(types.RelayDirectionBackendToClient, 150) +} diff --git a/proxy/internal/metrics/metrics.go b/proxy/internal/metrics/metrics.go index 68ff55fe5..573485625 100644 --- a/proxy/internal/metrics/metrics.go +++ b/proxy/internal/metrics/metrics.go @@ -6,12 +6,15 @@ import ( "sync" "time" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/responsewriter" + "github.com/netbirdio/netbird/proxy/internal/types" ) +// Metrics collects OpenTelemetry metrics for the proxy. type Metrics struct { ctx context.Context requestsTotal metric.Int64Counter @@ -22,85 +25,188 @@ type Metrics struct { backendDuration metric.Int64Histogram certificateIssueDuration metric.Int64Histogram + // L4 service-level metrics. + l4Services metric.Int64UpDownCounter + + // L4 TCP connection-level metrics. + tcpActiveConns metric.Int64UpDownCounter + tcpConnsTotal metric.Int64Counter + tcpConnDuration metric.Int64Histogram + tcpBytesTotal metric.Int64Counter + + // L4 UDP session-level metrics. + udpActiveSess metric.Int64UpDownCounter + udpSessionsTotal metric.Int64Counter + udpPacketsTotal metric.Int64Counter + udpBytesTotal metric.Int64Counter + mappingsMux sync.Mutex mappingPaths map[string]int } +// New creates a Metrics instance using the given OpenTelemetry meter. func New(ctx context.Context, meter metric.Meter) (*Metrics, error) { - requestsTotal, err := meter.Int64Counter( + m := &Metrics{ + ctx: ctx, + mappingPaths: make(map[string]int), + } + + if err := m.initHTTPMetrics(meter); err != nil { + return nil, err + } + if err := m.initL4Metrics(meter); err != nil { + return nil, err + } + + return m, nil +} + +func (m *Metrics) initHTTPMetrics(meter metric.Meter) error { + var err error + + m.requestsTotal, err = meter.Int64Counter( "proxy.http.request.counter", metric.WithUnit("1"), metric.WithDescription("Total number of requests made to the netbird proxy"), ) if err != nil { - return nil, err + return err } - activeRequests, err := meter.Int64UpDownCounter( + m.activeRequests, err = meter.Int64UpDownCounter( "proxy.http.active_requests", metric.WithUnit("1"), metric.WithDescription("Current in-flight requests handled by the netbird proxy"), ) if err != nil { - return nil, err + return err } - configuredDomains, err := meter.Int64UpDownCounter( + m.configuredDomains, err = meter.Int64UpDownCounter( "proxy.domains.count", metric.WithUnit("1"), metric.WithDescription("Current number of domains configured on the netbird proxy"), ) if err != nil { - return nil, err + return err } - totalPaths, err := meter.Int64UpDownCounter( + m.totalPaths, err = meter.Int64UpDownCounter( "proxy.paths.count", metric.WithUnit("1"), metric.WithDescription("Total number of paths configured on the netbird proxy"), ) if err != nil { - return nil, err + return err } - requestDuration, err := meter.Int64Histogram( + m.requestDuration, err = meter.Int64Histogram( "proxy.http.request.duration.ms", metric.WithUnit("milliseconds"), metric.WithDescription("Duration of requests made to the netbird proxy"), ) if err != nil { - return nil, err + return err } - backendDuration, err := meter.Int64Histogram( + m.backendDuration, err = meter.Int64Histogram( "proxy.backend.duration.ms", metric.WithUnit("milliseconds"), metric.WithDescription("Duration of peer round trip time from the netbird proxy"), ) if err != nil { - return nil, err + return err } - certificateIssueDuration, err := meter.Int64Histogram( + m.certificateIssueDuration, err = meter.Int64Histogram( "proxy.certificate.issue.duration.ms", metric.WithUnit("milliseconds"), metric.WithDescription("Duration of ACME certificate issuance"), ) + return err +} + +func (m *Metrics) initL4Metrics(meter metric.Meter) error { + var err error + + m.l4Services, err = meter.Int64UpDownCounter( + "proxy.l4.services.count", + metric.WithUnit("1"), + metric.WithDescription("Current number of configured L4 services (TCP/TLS/UDP) by mode"), + ) if err != nil { - return nil, err + return err } - return &Metrics{ - ctx: ctx, - requestsTotal: requestsTotal, - activeRequests: activeRequests, - configuredDomains: configuredDomains, - totalPaths: totalPaths, - requestDuration: requestDuration, - backendDuration: backendDuration, - certificateIssueDuration: certificateIssueDuration, - mappingPaths: make(map[string]int), - }, nil + m.tcpActiveConns, err = meter.Int64UpDownCounter( + "proxy.tcp.active_connections", + metric.WithUnit("1"), + metric.WithDescription("Current number of active TCP/TLS relay connections"), + ) + if err != nil { + return err + } + + m.tcpConnsTotal, err = meter.Int64Counter( + "proxy.tcp.connections.total", + metric.WithUnit("1"), + metric.WithDescription("Total TCP/TLS relay connections by result and account"), + ) + if err != nil { + return err + } + + m.tcpConnDuration, err = meter.Int64Histogram( + "proxy.tcp.connection.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration of TCP/TLS relay connections"), + ) + if err != nil { + return err + } + + m.tcpBytesTotal, err = meter.Int64Counter( + "proxy.tcp.bytes.total", + metric.WithUnit("bytes"), + metric.WithDescription("Total bytes transferred through TCP/TLS relay by direction"), + ) + if err != nil { + return err + } + + m.udpActiveSess, err = meter.Int64UpDownCounter( + "proxy.udp.active_sessions", + metric.WithUnit("1"), + metric.WithDescription("Current number of active UDP relay sessions"), + ) + if err != nil { + return err + } + + m.udpSessionsTotal, err = meter.Int64Counter( + "proxy.udp.sessions.total", + metric.WithUnit("1"), + metric.WithDescription("Total UDP relay sessions by result and account"), + ) + if err != nil { + return err + } + + m.udpPacketsTotal, err = meter.Int64Counter( + "proxy.udp.packets.total", + metric.WithUnit("1"), + metric.WithDescription("Total UDP packets relayed by direction"), + ) + if err != nil { + return err + } + + m.udpBytesTotal, err = meter.Int64Counter( + "proxy.udp.bytes.total", + metric.WithUnit("bytes"), + metric.WithDescription("Total bytes transferred through UDP relay by direction"), + ) + return err } type responseInterceptor struct { @@ -120,6 +226,13 @@ func (w *responseInterceptor) Write(b []byte) (int, error) { return size, err } +// Unwrap returns the underlying ResponseWriter so http.ResponseController +// can reach through to the original writer for Hijack/Flush operations. +func (w *responseInterceptor) Unwrap() http.ResponseWriter { + return w.PassthroughWriter +} + +// Middleware wraps an HTTP handler with request metrics. func (m *Metrics) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.requestsTotal.Add(m.ctx, 1) @@ -144,6 +257,7 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } +// RoundTripper wraps an http.RoundTripper with backend duration metrics. func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (*http.Response, error) { start := time.Now() @@ -156,6 +270,7 @@ func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper { }) } +// AddMapping records that a domain mapping was added. func (m *Metrics) AddMapping(mapping proxy.Mapping) { m.mappingsMux.Lock() defer m.mappingsMux.Unlock() @@ -175,13 +290,13 @@ func (m *Metrics) AddMapping(mapping proxy.Mapping) { m.mappingPaths[mapping.Host] = newPathCount } +// RemoveMapping records that a domain mapping was removed. func (m *Metrics) RemoveMapping(mapping proxy.Mapping) { m.mappingsMux.Lock() defer m.mappingsMux.Unlock() oldPathCount, exists := m.mappingPaths[mapping.Host] if !exists { - // Nothing to remove return } @@ -195,3 +310,80 @@ func (m *Metrics) RemoveMapping(mapping proxy.Mapping) { func (m *Metrics) RecordCertificateIssuance(duration time.Duration) { m.certificateIssueDuration.Record(m.ctx, duration.Milliseconds()) } + +// L4ServiceAdded increments the L4 service gauge for the given mode. +func (m *Metrics) L4ServiceAdded(mode types.ServiceMode) { + m.l4Services.Add(m.ctx, 1, metric.WithAttributes(attribute.String("mode", string(mode)))) +} + +// L4ServiceRemoved decrements the L4 service gauge for the given mode. +func (m *Metrics) L4ServiceRemoved(mode types.ServiceMode) { + m.l4Services.Add(m.ctx, -1, metric.WithAttributes(attribute.String("mode", string(mode)))) +} + +// TCPRelayStarted records a new TCP relay connection starting. +func (m *Metrics) TCPRelayStarted(accountID types.AccountID) { + acct := attribute.String("account_id", string(accountID)) + m.tcpActiveConns.Add(m.ctx, 1, metric.WithAttributes(acct)) + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes(acct, attribute.String("result", "success"))) +} + +// TCPRelayEnded records a TCP relay connection ending and accumulates bytes and duration. +func (m *Metrics) TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64) { + acct := attribute.String("account_id", string(accountID)) + m.tcpActiveConns.Add(m.ctx, -1, metric.WithAttributes(acct)) + m.tcpConnDuration.Record(m.ctx, duration.Milliseconds(), metric.WithAttributes(acct)) + m.tcpBytesTotal.Add(m.ctx, srcToDst, metric.WithAttributes(attribute.String("direction", "client_to_backend"))) + m.tcpBytesTotal.Add(m.ctx, dstToSrc, metric.WithAttributes(attribute.String("direction", "backend_to_client"))) +} + +// TCPRelayDialError records a dial failure for a TCP relay. +func (m *Metrics) TCPRelayDialError(accountID types.AccountID) { + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "dial_error"), + )) +} + +// TCPRelayRejected records a rejected TCP relay (semaphore full). +func (m *Metrics) TCPRelayRejected(accountID types.AccountID) { + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "rejected"), + )) +} + +// UDPSessionStarted records a new UDP session starting. +func (m *Metrics) UDPSessionStarted(accountID types.AccountID) { + acct := attribute.String("account_id", string(accountID)) + m.udpActiveSess.Add(m.ctx, 1, metric.WithAttributes(acct)) + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes(acct, attribute.String("result", "success"))) +} + +// UDPSessionEnded records a UDP session ending. +func (m *Metrics) UDPSessionEnded(accountID types.AccountID) { + m.udpActiveSess.Add(m.ctx, -1, metric.WithAttributes(attribute.String("account_id", string(accountID)))) +} + +// UDPSessionDialError records a dial failure for a UDP session. +func (m *Metrics) UDPSessionDialError(accountID types.AccountID) { + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "dial_error"), + )) +} + +// UDPSessionRejected records a rejected UDP session (limit or rate limited). +func (m *Metrics) UDPSessionRejected(accountID types.AccountID) { + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "rejected"), + )) +} + +// UDPPacketRelayed records a packet relayed in the given direction with its size in bytes. +func (m *Metrics) UDPPacketRelayed(direction types.RelayDirection, bytes int) { + dir := attribute.String("direction", string(direction)) + m.udpPacketsTotal.Add(m.ctx, 1, metric.WithAttributes(dir)) + m.udpBytesTotal.Add(m.ctx, int64(bytes), metric.WithAttributes(dir)) +} diff --git a/proxy/internal/netutil/errors.go b/proxy/internal/netutil/errors.go new file mode 100644 index 000000000..ff24e33d4 --- /dev/null +++ b/proxy/internal/netutil/errors.go @@ -0,0 +1,40 @@ +package netutil + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "net" + "syscall" +) + +// ValidatePort converts an int32 proto port to uint16, returning an error +// if the value is out of the valid 1–65535 range. +func ValidatePort(port int32) (uint16, error) { + if port <= 0 || port > math.MaxUint16 { + return 0, fmt.Errorf("invalid port %d: must be 1–65535", port) + } + return uint16(port), nil +} + +// IsExpectedError returns true for errors that are normal during +// connection teardown and should not be logged as warnings. +func IsExpectedError(err error) bool { + return errors.Is(err, net.ErrClosed) || + errors.Is(err, context.Canceled) || + errors.Is(err, io.EOF) || + errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, syscall.EPIPE) || + errors.Is(err, syscall.ECONNABORTED) +} + +// IsTimeout checks whether the error is a network timeout. +func IsTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + return false +} diff --git a/proxy/internal/netutil/errors_test.go b/proxy/internal/netutil/errors_test.go new file mode 100644 index 000000000..7d6be10ff --- /dev/null +++ b/proxy/internal/netutil/errors_test.go @@ -0,0 +1,92 @@ +package netutil + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidatePort(t *testing.T) { + tests := []struct { + name string + port int32 + want uint16 + wantErr bool + }{ + {"valid min", 1, 1, false}, + {"valid mid", 8080, 8080, false}, + {"valid max", 65535, 65535, false}, + {"zero", 0, 0, true}, + {"negative", -1, 0, true}, + {"too large", 65536, 0, true}, + {"way too large", 100000, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ValidatePort(tt.port) + if tt.wantErr { + assert.Error(t, err) + assert.Zero(t, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestIsExpectedError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"net.ErrClosed", net.ErrClosed, true}, + {"context.Canceled", context.Canceled, true}, + {"io.EOF", io.EOF, true}, + {"ECONNRESET", syscall.ECONNRESET, true}, + {"EPIPE", syscall.EPIPE, true}, + {"ECONNABORTED", syscall.ECONNABORTED, true}, + {"wrapped expected", fmt.Errorf("wrap: %w", net.ErrClosed), true}, + {"unexpected EOF", io.ErrUnexpectedEOF, false}, + {"generic error", errors.New("something"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IsExpectedError(tt.err)) + }) + } +} + +type timeoutErr struct{ timeout bool } + +func (e *timeoutErr) Error() string { return "timeout" } +func (e *timeoutErr) Timeout() bool { return e.timeout } +func (e *timeoutErr) Temporary() bool { return false } + +func TestIsTimeout(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"net timeout", &timeoutErr{timeout: true}, true}, + {"net non-timeout", &timeoutErr{timeout: false}, false}, + {"wrapped timeout", fmt.Errorf("wrap: %w", &timeoutErr{timeout: true}), true}, + {"generic error", errors.New("not a timeout"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IsTimeout(tt.err)) + }) + } +} diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index 22ebbf371..4a61f6bcf 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "net/netip" "sync" "github.com/netbirdio/netbird/proxy/internal/types" @@ -47,10 +48,10 @@ func (o ResponseOrigin) String() string { type CapturedData struct { mu sync.RWMutex RequestID string - ServiceId string + ServiceId types.ServiceID AccountId types.AccountID Origin ResponseOrigin - ClientIP string + ClientIP netip.Addr UserID string AuthMethod string } @@ -63,14 +64,14 @@ func (c *CapturedData) GetRequestID() string { } // SetServiceId safely sets the service ID -func (c *CapturedData) SetServiceId(serviceId string) { +func (c *CapturedData) SetServiceId(serviceId types.ServiceID) { c.mu.Lock() defer c.mu.Unlock() c.ServiceId = serviceId } // GetServiceId safely gets the service ID -func (c *CapturedData) GetServiceId() string { +func (c *CapturedData) GetServiceId() types.ServiceID { c.mu.RLock() defer c.mu.RUnlock() return c.ServiceId @@ -105,14 +106,14 @@ func (c *CapturedData) GetOrigin() ResponseOrigin { } // SetClientIP safely sets the resolved client IP. -func (c *CapturedData) SetClientIP(ip string) { +func (c *CapturedData) SetClientIP(ip netip.Addr) { c.mu.Lock() defer c.mu.Unlock() c.ClientIP = ip } // GetClientIP safely gets the resolved client IP. -func (c *CapturedData) GetClientIP() string { +func (c *CapturedData) GetClientIP() netip.Addr { c.mu.RLock() defer c.mu.RUnlock() return c.ClientIP @@ -161,13 +162,13 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData { return data } -func withServiceId(ctx context.Context, serviceId string) context.Context { +func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context { return context.WithValue(ctx, serviceIdKey, serviceId) } -func ServiceIdFromContext(ctx context.Context) string { +func ServiceIdFromContext(ctx context.Context) types.ServiceID { v := ctx.Value(serviceIdKey) - serviceId, ok := v.(string) + serviceId, ok := v.(types.ServiceID) if !ok { return "" } diff --git a/proxy/internal/proxy/proxy_bench_test.go b/proxy/internal/proxy/proxy_bench_test.go index 5af2167e6..b59ef75c0 100644 --- a/proxy/internal/proxy/proxy_bench_test.go +++ b/proxy/internal/proxy/proxy_bench_test.go @@ -25,7 +25,7 @@ func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) { func BenchmarkServeHTTP(b *testing.B) { rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil) rp.AddMapping(proxy.Mapping{ - ID: rand.Text(), + ID: types.ServiceID(rand.Text()), AccountID: types.AccountID(rand.Text()), Host: "app.example.com", Paths: map[string]*proxy.PathTarget{ @@ -66,7 +66,7 @@ func BenchmarkServeHTTPHostCount(b *testing.B) { target = id } rp.AddMapping(proxy.Mapping{ - ID: id, + ID: types.ServiceID(id), AccountID: types.AccountID(rand.Text()), Host: host, Paths: map[string]*proxy.PathTarget{ @@ -118,7 +118,7 @@ func BenchmarkServeHTTPPathCount(b *testing.B) { } } rp.AddMapping(proxy.Mapping{ - ID: rand.Text(), + ID: types.ServiceID(rand.Text()), AccountID: types.AccountID(rand.Text()), Host: "app.example.com", Paths: paths, diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index b0001d5b9..1ee9b2a42 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/internal/roundtrip" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/web" ) @@ -86,9 +87,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx = roundtrip.WithSkipTLSVerify(ctx) } if pt.RequestTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout) - defer cancel() + ctx = types.WithDialTimeout(ctx, pt.RequestTimeout) } rewriteMatchedPath := result.matchedPath @@ -142,9 +141,9 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost r.Out.Header.Set(k, v) } - clientIP := extractClientIP(r.In.RemoteAddr) + clientIP := extractHostIP(r.In.RemoteAddr) - if IsTrustedProxy(clientIP, p.trustedProxies) { + if isTrustedAddr(clientIP, p.trustedProxies) { p.setTrustedForwardingHeaders(r, clientIP) } else { p.setUntrustedForwardingHeaders(r, clientIP) @@ -214,12 +213,14 @@ func normalizeHost(u *url.URL) string { // setTrustedForwardingHeaders appends to the existing forwarding header chain // and preserves upstream-provided headers when the direct connection is from // a trusted proxy. -func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) { +func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) { + ipStr := clientIP.String() + // Append the direct connection IP to the existing X-Forwarded-For chain. if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" { - r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP) + r.Out.Header.Set("X-Forwarded-For", existing+", "+ipStr) } else { - r.Out.Header.Set("X-Forwarded-For", clientIP) + r.Out.Header.Set("X-Forwarded-For", ipStr) } // Preserve upstream X-Real-IP if present; otherwise resolve through the chain. @@ -227,7 +228,7 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli r.Out.Header.Set("X-Real-IP", realIP) } else { resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies) - r.Out.Header.Set("X-Real-IP", resolved) + r.Out.Header.Set("X-Real-IP", resolved.String()) } // Preserve upstream X-Forwarded-Host if present. @@ -257,10 +258,11 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli // sets them fresh based on the direct connection. This is the default // behavior when no trusted proxies are configured or the direct connection // is from an untrusted source. -func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) { +func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) { + ipStr := clientIP.String() proto := auth.ResolveProto(p.forwardedProto, r.In.TLS) - r.Out.Header.Set("X-Forwarded-For", clientIP) - r.Out.Header.Set("X-Real-IP", clientIP) + r.Out.Header.Set("X-Forwarded-For", ipStr) + r.Out.Header.Set("X-Real-IP", ipStr) r.Out.Header.Set("X-Forwarded-Host", r.In.Host) r.Out.Header.Set("X-Forwarded-Proto", proto) r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto)) @@ -288,16 +290,6 @@ func stripSessionTokenQuery(r *httputil.ProxyRequest) { } } -// extractClientIP extracts the IP address from an http.Request.RemoteAddr -// which is always in host:port format. -func extractClientIP(remoteAddr string) string { - ip, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return remoteAddr - } - return ip -} - // extractForwardedPort returns the port from the Host header if present, // otherwise defaults to the standard port for the resolved protocol. func extractForwardedPort(host, resolvedProto string) string { @@ -327,10 +319,12 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { web.ServeErrorPage(w, r, code, title, message, requestID, status) } -// getClientIP retrieves the resolved client IP from context. +// getClientIP retrieves the resolved client IP string from context. func getClientIP(r *http.Request) string { if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil { - return capturedData.GetClientIP() + if ip := capturedData.GetClientIP(); ip.IsValid() { + return ip.String() + } } return "" } diff --git a/proxy/internal/proxy/reverseproxy_test.go b/proxy/internal/proxy/reverseproxy_test.go index be2fb9105..b05ead198 100644 --- a/proxy/internal/proxy/reverseproxy_test.go +++ b/proxy/internal/proxy/reverseproxy_test.go @@ -284,23 +284,23 @@ func TestRewriteFunc_URLRewriting(t *testing.T) { }) } -func TestExtractClientIP(t *testing.T) { +func TestExtractHostIP(t *testing.T) { tests := []struct { name string remoteAddr string - expected string + expected netip.Addr }{ - {"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"}, - {"IPv6 with port", "[::1]:12345", "::1"}, - {"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"}, - {"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"}, - {"IPv6 without brackets fallback", "::1", "::1"}, - {"empty string fallback", "", ""}, - {"public IP", "203.0.113.50:9999", "203.0.113.50"}, + {"IPv4 with port", "192.168.1.1:12345", netip.MustParseAddr("192.168.1.1")}, + {"IPv6 with port", "[::1]:12345", netip.MustParseAddr("::1")}, + {"IPv6 full with port", "[2001:db8::1]:443", netip.MustParseAddr("2001:db8::1")}, + {"IPv4 without port fallback", "192.168.1.1", netip.MustParseAddr("192.168.1.1")}, + {"IPv6 without brackets fallback", "::1", netip.MustParseAddr("::1")}, + {"empty string fallback", "", netip.Addr{}}, + {"public IP", "203.0.113.50:9999", netip.MustParseAddr("203.0.113.50")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr)) + assert.Equal(t, tt.expected, extractHostIP(tt.remoteAddr)) }) } } diff --git a/proxy/internal/proxy/servicemapping.go b/proxy/internal/proxy/servicemapping.go index 58b92ff9e..1513fbe45 100644 --- a/proxy/internal/proxy/servicemapping.go +++ b/proxy/internal/proxy/servicemapping.go @@ -30,8 +30,9 @@ type PathTarget struct { CustomHeaders map[string]string } +// Mapping describes how a domain is routed by the HTTP reverse proxy. type Mapping struct { - ID string + ID types.ServiceID AccountID types.AccountID Host string Paths map[string]*PathTarget @@ -42,7 +43,7 @@ type Mapping struct { type targetResult struct { target *PathTarget matchedPath string - serviceID string + serviceID types.ServiceID accountID types.AccountID passHostHeader bool rewriteRedirects bool @@ -101,8 +102,13 @@ func (p *ReverseProxy) AddMapping(m Mapping) { p.mappings[m.Host] = m } -func (p *ReverseProxy) RemoveMapping(m Mapping) { +// RemoveMapping removes the mapping for the given host and reports whether it existed. +func (p *ReverseProxy) RemoveMapping(m Mapping) bool { p.mappingsMux.Lock() defer p.mappingsMux.Unlock() + if _, ok := p.mappings[m.Host]; !ok { + return false + } delete(p.mappings, m.Host) + return true } diff --git a/proxy/internal/proxy/trustedproxy.go b/proxy/internal/proxy/trustedproxy.go index ad9a5b6c0..0fe693f90 100644 --- a/proxy/internal/proxy/trustedproxy.go +++ b/proxy/internal/proxy/trustedproxy.go @@ -7,21 +7,11 @@ import ( // IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes. func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool { - if len(trusted) == 0 { - return false - } - addr, err := netip.ParseAddr(ipStr) - if err != nil { + if err != nil || len(trusted) == 0 { return false } - - for _, prefix := range trusted { - if prefix.Contains(addr) { - return true - } - } - return false + return isTrustedAddr(addr.Unmap(), trusted) } // ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list. @@ -30,10 +20,10 @@ func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool { // // If the trusted list is empty or remoteAddr is not trusted, it returns the // remoteAddr IP directly (ignoring any forwarding headers). -func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string { - remoteIP := extractClientIP(remoteAddr) +func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) netip.Addr { + remoteIP := extractHostIP(remoteAddr) - if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) { + if len(trusted) == 0 || !isTrustedAddr(remoteIP, trusted) { return remoteIP } @@ -47,14 +37,45 @@ func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string { if ip == "" { continue } - if !IsTrustedProxy(ip, trusted) { - return ip + addr, err := netip.ParseAddr(ip) + if err != nil { + continue + } + addr = addr.Unmap() + if !isTrustedAddr(addr, trusted) { + return addr } } // All IPs in XFF are trusted; return the leftmost as best guess. if first := strings.TrimSpace(parts[0]); first != "" { - return first + if addr, err := netip.ParseAddr(first); err == nil { + return addr.Unmap() + } } return remoteIP } + +// extractHostIP parses the IP from a host:port string and returns it unmapped. +func extractHostIP(hostPort string) netip.Addr { + if ap, err := netip.ParseAddrPort(hostPort); err == nil { + return ap.Addr().Unmap() + } + if addr, err := netip.ParseAddr(hostPort); err == nil { + return addr.Unmap() + } + return netip.Addr{} +} + +// isTrustedAddr checks if the given address falls within any of the trusted prefixes. +func isTrustedAddr(addr netip.Addr, trusted []netip.Prefix) bool { + if !addr.IsValid() { + return false + } + for _, prefix := range trusted { + if prefix.Contains(addr) { + return true + } + } + return false +} diff --git a/proxy/internal/proxy/trustedproxy_test.go b/proxy/internal/proxy/trustedproxy_test.go index 827b7babf..35ed1f5c2 100644 --- a/proxy/internal/proxy/trustedproxy_test.go +++ b/proxy/internal/proxy/trustedproxy_test.go @@ -48,77 +48,77 @@ func TestResolveClientIP(t *testing.T) { remoteAddr string xff string trusted []netip.Prefix - want string + want netip.Addr }{ { name: "empty trusted list returns RemoteAddr", remoteAddr: "203.0.113.50:9999", xff: "1.2.3.4", trusted: nil, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "untrusted RemoteAddr ignores XFF", remoteAddr: "203.0.113.50:9999", xff: "1.2.3.4, 10.0.0.1", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr with single client in XFF", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr walks past trusted entries in XFF", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50, 10.0.0.2, 172.16.0.5", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr", remoteAddr: "10.0.0.1:5000", xff: "", trusted: trusted, - want: "10.0.0.1", + want: netip.MustParseAddr("10.0.0.1"), }, { name: "all XFF IPs trusted returns leftmost", remoteAddr: "10.0.0.1:5000", xff: "10.0.0.2, 172.16.0.1, 10.0.0.3", trusted: trusted, - want: "10.0.0.2", + want: netip.MustParseAddr("10.0.0.2"), }, { name: "XFF with whitespace", remoteAddr: "10.0.0.1:5000", xff: " 203.0.113.50 , 10.0.0.2 ", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "XFF with empty segments", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50,,10.0.0.2", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "multi-hop with mixed trust", remoteAddr: "10.0.0.1:5000", xff: "8.8.8.8, 203.0.113.50, 172.16.0.1", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "RemoteAddr without port", remoteAddr: "10.0.0.1", xff: "203.0.113.50", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, } for _, tt := range tests { diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index 57770f4a5..e38e3dc4e 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "sync" "time" @@ -14,11 +15,12 @@ import ( "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/embed" nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -26,7 +28,22 @@ import ( const deviceNamePrefix = "ingress-proxy-" // backendKey identifies a backend by its host:port from the target URL. -type backendKey = string +type backendKey string + +// ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service) +// that holds a reference to an embedded NetBird client. Callers should use the +// DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions. +type ServiceKey string + +// DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service. +func DomainServiceKey(domain string) ServiceKey { + return ServiceKey("domain:" + domain) +} + +// L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP). +func L4ServiceKey(id types.ServiceID) ServiceKey { + return ServiceKey("l4:" + id) +} var ( // ErrNoAccountID is returned when a request context is missing the account ID. @@ -39,24 +56,24 @@ var ( ErrTooManyInflight = errors.New("too many in-flight requests") ) -// domainInfo holds metadata about a registered domain. -type domainInfo struct { - serviceID string +// serviceInfo holds metadata about a registered service. +type serviceInfo struct { + serviceID types.ServiceID } -type domainNotification struct { - domain domain.Domain - serviceID string +type serviceNotification struct { + key ServiceKey + serviceID types.ServiceID } -// clientEntry holds an embedded NetBird client and tracks which domains use it. +// clientEntry holds an embedded NetBird client and tracks which services use it. type clientEntry struct { client *embed.Client transport *http.Transport // insecureTransport is a clone of transport with TLS verification disabled, // used when per-target skip_tls_verify is set. insecureTransport *http.Transport - domains map[domain.Domain]domainInfo + services map[ServiceKey]serviceInfo createdAt time.Time started bool // Per-backend in-flight limiting keyed by target host:port. @@ -93,12 +110,12 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo // ClientConfig holds configuration for the embedded NetBird client. type ClientConfig struct { MgmtAddr string - WGPort int + WGPort uint16 PreSharedKey string } type statusNotifier interface { - NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error + NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error } type managementClient interface { @@ -107,7 +124,7 @@ type managementClient interface { // NetBird provides an http.RoundTripper implementation // backed by underlying NetBird connections. -// Clients are keyed by AccountID, allowing multiple domains to share the same connection. +// Clients are keyed by AccountID, allowing multiple services to share the same connection. type NetBird struct { proxyID string proxyAddr string @@ -124,11 +141,11 @@ type NetBird struct { // ClientDebugInfo contains debug information about a client. type ClientDebugInfo struct { - AccountID types.AccountID - DomainCount int - Domains domain.List - HasClient bool - CreatedAt time.Time + AccountID types.AccountID + ServiceCount int + ServiceKeys []string + HasClient bool + CreatedAt time.Time } // accountIDContextKey is the context key for storing the account ID. @@ -137,37 +154,37 @@ type accountIDContextKey struct{} // skipTLSVerifyContextKey is the context key for requesting insecure TLS. type skipTLSVerifyContextKey struct{} -// AddPeer registers a domain for an account. If the account doesn't have a client yet, +// AddPeer registers a service for an account. If the account doesn't have a client yet, // one is created by authenticating with the management server using the provided token. -// Multiple domains can share the same client. -func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error { +// Multiple services can share the same client. +func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error { + si := serviceInfo{serviceID: serviceID} + n.clientsMux.Lock() entry, exists := n.clients[accountID] if exists { - // Client already exists for this account, just register the domain - entry.domains[d] = domainInfo{serviceID: serviceID} + entry.services[key] = si started := entry.started n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).Debug("registered domain with existing client") + "account_id": accountID, + "service_key": key, + }).Debug("registered service with existing client") - // If client is already started, notify this domain as connected immediately if started && n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil { + if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, + "account_id": accountID, + "service_key": key, }).WithError(err).Warn("failed to notify status for existing client") } } return nil } - entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID) + entry, err := n.createClientEntry(ctx, accountID, key, authToken, si) if err != nil { n.clientsMux.Unlock() return err @@ -177,8 +194,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, + "account_id": accountID, + "service_key": key, }).Info("created new client for account") // Attempt to start the client in the background; if this fails we will @@ -190,7 +207,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma // createClientEntry generates a WireGuard keypair, authenticates with management, // and creates an embedded NetBird client. Must be called with clientsMux held. -func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) { +func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) { + serviceID := si.serviceID n.logger.WithFields(log.Fields{ "account_id": accountID, "service_id": serviceID, @@ -209,7 +227,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account }).Debug("authenticating new proxy peer with management") resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{ - ServiceId: serviceID, + ServiceId: string(serviceID), AccountId: string(accountID), Token: authToken, WireguardPublicKey: publicKey.String(), @@ -240,13 +258,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // Create embedded NetBird client with the generated private key. // The peer has already been created via CreateProxyPeer RPC with the public key. + wgPort := int(n.clientCfg.WGPort) client, err := embed.New(embed.Options{ DeviceName: deviceNamePrefix + n.proxyID, ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), LogLevel: log.WarnLevel.String(), BlockInbound: true, - WireguardPort: &n.clientCfg.WGPort, + WireguardPort: &wgPort, PreSharedKey: n.clientCfg.PreSharedKey, }) if err != nil { @@ -257,7 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // the client's HTTPClient to avoid issues with request validation that do // not work with reverse proxied requests. transport := &http.Transport{ - DialContext: client.DialContext, + DialContext: dialWithTimeout(client.DialContext), ForceAttemptHTTP2: true, MaxIdleConns: n.transportCfg.maxIdleConns, MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost, @@ -276,7 +295,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account return &clientEntry{ client: client, - domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}}, + services: map[ServiceKey]serviceInfo{key: si}, transport: transport, insecureTransport: insecureTransport, createdAt: time.Now(), @@ -286,7 +305,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account }, nil } -// runClientStartup starts the client and notifies registered domains on success. +// runClientStartup starts the client and notifies registered services on success. func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) { startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -300,16 +319,16 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI return } - // Mark client as started and collect domains to notify outside the lock. + // Mark client as started and collect services to notify outside the lock. n.clientsMux.Lock() entry, exists := n.clients[accountID] if exists { entry.started = true } - var domainsToNotify []domainNotification + var toNotify []serviceNotification if exists { - for dom, info := range entry.domains { - domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID}) + for key, info := range entry.services { + toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID}) } } n.clientsMux.Unlock() @@ -317,24 +336,24 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI if n.statusNotifier == nil { return } - for _, dn := range domainsToNotify { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil { + for _, sn := range toNotify { + if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": dn.domain, + "account_id": accountID, + "service_key": sn.key, }).WithError(err).Warn("failed to notify tunnel connection status") } else { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": dn.domain, + "account_id": accountID, + "service_key": sn.key, }).Info("notified management about tunnel connection") } } } -// RemovePeer unregisters a domain from an account. The client is only stopped -// when no domains are using it anymore. -func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error { +// RemovePeer unregisters a service from an account. The client is only stopped +// when no services are using it anymore. +func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error { n.clientsMux.Lock() entry, exists := n.clients[accountID] @@ -344,74 +363,65 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d return nil } - // Get domain info before deleting - domInfo, domainExists := entry.domains[d] - if !domainExists { + si, svcExists := entry.services[key] + if !svcExists { n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).Debug("remove peer: domain not registered") + "account_id": accountID, + "service_key": key, + }).Debug("remove peer: service not registered") return nil } - delete(entry.domains, d) - - // If there are still domains using this client, keep it running - if len(entry.domains) > 0 { - n.clientsMux.Unlock() + delete(entry.services, key) + stopClient := len(entry.services) == 0 + var client *embed.Client + var transport, insecureTransport *http.Transport + if stopClient { + n.logger.WithField("account_id", accountID).Info("stopping client, no more services") + client = entry.client + transport = entry.transport + insecureTransport = entry.insecureTransport + delete(n.clients, accountID) + } else { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - "remaining_domains": len(entry.domains), - }).Debug("unregistered domain, client still in use") - - // Notify this domain as disconnected - if n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).WithError(err).Warn("failed to notify tunnel disconnection status") - } - } - return nil + "account_id": accountID, + "service_key": key, + "remaining_services": len(entry.services), + }).Debug("unregistered service, client still in use") } - - // No more domains using this client, stop it - n.logger.WithFields(log.Fields{ - "account_id": accountID, - }).Info("stopping client, no more domains") - - client := entry.client - transport := entry.transport - insecureTransport := entry.insecureTransport - delete(n.clients, accountID) n.clientsMux.Unlock() - // Notify disconnection before stopping - if n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).WithError(err).Warn("failed to notify tunnel disconnection status") + n.notifyDisconnect(ctx, accountID, key, si.serviceID) + + if stopClient { + transport.CloseIdleConnections() + insecureTransport.CloseIdleConnections() + if err := client.Stop(ctx); err != nil { + n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client") } } - transport.CloseIdleConnections() - insecureTransport.CloseIdleConnections() - - if err := client.Stop(ctx); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - }).WithError(err).Warn("failed to stop netbird client") - } - return nil } +func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) { + if n.statusNotifier == nil { + return + } + if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil { + if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound { + n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification") + } else { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "service_key": key, + }).WithError(err).Warn("failed to notify tunnel disconnection status") + } + } +} + // RoundTrip implements http.RoundTripper. It looks up the client for the account // specified in the request context and uses it to dial the backend. func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { @@ -435,7 +445,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { } n.clientsMux.RUnlock() - release, ok := entry.acquireInflight(req.URL.Host) + release, ok := entry.acquireInflight(backendKey(req.URL.Host)) defer release() if !ok { return nil, ErrTooManyInflight @@ -496,16 +506,16 @@ func (n *NetBird) HasClient(accountID types.AccountID) bool { return exists } -// DomainCount returns the number of domains registered for the given account. +// ServiceCount returns the number of services registered for the given account. // Returns 0 if the account has no client. -func (n *NetBird) DomainCount(accountID types.AccountID) int { +func (n *NetBird) ServiceCount(accountID types.AccountID) int { n.clientsMux.RLock() defer n.clientsMux.RUnlock() entry, exists := n.clients[accountID] if !exists { return 0 } - return len(entry.domains) + return len(entry.services) } // ClientCount returns the total number of active clients. @@ -533,16 +543,16 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo { result := make(map[types.AccountID]ClientDebugInfo) for accountID, entry := range n.clients { - domains := make(domain.List, 0, len(entry.domains)) - for d := range entry.domains { - domains = append(domains, d) + keys := make([]string, 0, len(entry.services)) + for k := range entry.services { + keys = append(keys, string(k)) } result[accountID] = ClientDebugInfo{ - AccountID: accountID, - DomainCount: len(entry.domains), - Domains: domains, - HasClient: entry.client != nil, - CreatedAt: entry.createdAt, + AccountID: accountID, + ServiceCount: len(entry.services), + ServiceKeys: keys, + HasClient: entry.client != nil, + CreatedAt: entry.createdAt, } } return result @@ -581,6 +591,20 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L } } +// dialWithTimeout wraps a DialContext function so that any dial timeout +// stored in the context (via types.WithDialTimeout) is applied only to +// the connection establishment phase, not the full request lifetime. +func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + if d, ok := types.DialTimeoutFromContext(ctx); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, d) + defer cancel() + } + return dial(ctx, network, addr) + } +} + // WithAccountID adds the account ID to the context. func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context { return context.WithValue(ctx, accountIDContextKey{}, accountID) diff --git a/proxy/internal/roundtrip/netbird_bench_test.go b/proxy/internal/roundtrip/netbird_bench_test.go index e89213c33..330ea0332 100644 --- a/proxy/internal/roundtrip/netbird_bench_test.go +++ b/proxy/internal/roundtrip/netbird_bench_test.go @@ -1,6 +1,7 @@ package roundtrip import ( + "context" "crypto/rand" "math/big" "sync" @@ -8,7 +9,6 @@ import ( "time" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" ) // Simple benchmark for comparison with AddPeer contention. @@ -29,9 +29,9 @@ func BenchmarkHasClient(b *testing.B) { target = id } nb.clients[id] = &clientEntry{ - domains: map[domain.Domain]domainInfo{ - domain.Domain(rand.Text()): { - serviceID: rand.Text(), + services: map[ServiceKey]serviceInfo{ + ServiceKey(rand.Text()): { + serviceID: types.ServiceID(rand.Text()), }, }, createdAt: time.Now(), @@ -70,9 +70,9 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { target = id } nb.clients[id] = &clientEntry{ - domains: map[domain.Domain]domainInfo{ - domain.Domain(rand.Text()): { - serviceID: rand.Text(), + services: map[ServiceKey]serviceInfo{ + ServiceKey(rand.Text()): { + serviceID: types.ServiceID(rand.Text()), }, }, createdAt: time.Now(), @@ -81,19 +81,22 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { } // Launch workers that continuously call AddPeer with new random accountIDs. + ctx, cancel := context.WithCancel(b.Context()) var wg sync.WaitGroup for range addPeerWorkers { - wg.Go(func() { - for { - if err := nb.AddPeer(b.Context(), + wg.Add(1) + go func() { + defer wg.Done() + for ctx.Err() == nil { + if err := nb.AddPeer(ctx, types.AccountID(rand.Text()), - domain.Domain(rand.Text()), + ServiceKey(rand.Text()), rand.Text(), - rand.Text()); err != nil { - b.Log(err) + types.ServiceID(rand.Text())); err != nil { + return } } - }) + }() } // Benchmark calling HasClient during AddPeer contention. @@ -104,4 +107,6 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { } }) b.StopTimer() + cancel() + wg.Wait() } diff --git a/proxy/internal/roundtrip/netbird_test.go b/proxy/internal/roundtrip/netbird_test.go index 0a742c2fa..5444f6c11 100644 --- a/proxy/internal/roundtrip/netbird_test.go +++ b/proxy/internal/roundtrip/netbird_test.go @@ -11,7 +11,6 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -27,16 +26,15 @@ type mockStatusNotifier struct { } type statusCall struct { - accountID string - serviceID string - domain string + accountID types.AccountID + serviceID types.ServiceID connected bool } -func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error { +func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error { m.mu.Lock() defer m.mu.Unlock() - m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected}) + m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected}) return nil } @@ -62,36 +60,34 @@ func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) { // Initially no client exists. assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer") - assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0") + assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0") - // Add first domain - this should create a new client. - // Note: This will fail to actually connect since we use an invalid URL, - // but the client entry should still be created. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add first service - this should create a new client. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.True(t, nb.HasClient(accountID), "should have client after AddPeer") - assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1") + assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1") } func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add first domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add first service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - assert.Equal(t, 1, nb.DomainCount(accountID)) + assert.Equal(t, 1, nb.ServiceCount(accountID)) - // Add second domain for the same account - should reuse existing client. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2") + // Add second service for the same account - should reuse existing client. + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2")) require.NoError(t, err) - assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain") + assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2 after adding second service") - // Add third domain. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3") + // Add third service. + err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3")) require.NoError(t, err) - assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain") + assert.Equal(t, 3, nb.ServiceCount(accountID), "service count should be 3 after adding third service") // Still only one client. assert.True(t, nb.HasClient(accountID)) @@ -102,64 +98,62 @@ func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) { account1 := types.AccountID("account-1") account2 := types.AccountID("account-2") - // Add domain for account 1. - err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add service for account 1. + err := nb.AddPeer(context.Background(), account1, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - // Add domain for account 2. - err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2") + // Add service for account 2. + err = nb.AddPeer(context.Background(), account2, "domain2.test", "setup-key-2", types.ServiceID("proxy-2")) require.NoError(t, err) // Both accounts should have their own clients. assert.True(t, nb.HasClient(account1), "account1 should have client") assert.True(t, nb.HasClient(account2), "account2 should have client") - assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1") - assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1") + assert.Equal(t, 1, nb.ServiceCount(account1), "account1 service count should be 1") + assert.Equal(t, 1, nb.ServiceCount(account2), "account2 service count should be 1") } -func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) { +func TestNetBird_RemovePeer_KeepsClientWhenServicesRemain(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add multiple domains. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add multiple services. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2") + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3") + err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3")) require.NoError(t, err) - assert.Equal(t, 3, nb.DomainCount(accountID)) + assert.Equal(t, 3, nb.ServiceCount(accountID)) - // Remove one domain - client should remain. + // Remove one service - client should remain. err = nb.RemovePeer(context.Background(), accountID, "domain1.test") require.NoError(t, err) - assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain") - assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2") + assert.True(t, nb.HasClient(accountID), "client should remain after removing one service") + assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2") - // Remove another domain - client should still remain. + // Remove another service - client should still remain. err = nb.RemovePeer(context.Background(), accountID, "domain2.test") require.NoError(t, err) - assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain") - assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1") + assert.True(t, nb.HasClient(accountID), "client should remain after removing second service") + assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1") } -func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) { +func TestNetBird_RemovePeer_RemovesClientWhenLastServiceRemoved(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add single domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add single service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.True(t, nb.HasClient(accountID)) - // Remove the only domain - client should be removed. - // Note: Stop() may fail since the client never actually connected, - // but the entry should still be removed from the map. + // Remove the only service - client should be removed. _ = nb.RemovePeer(context.Background(), accountID, "domain1.test") - // After removing all domains, client should be gone. - assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain") - assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0") + // After removing all services, client should be gone. + assert.False(t, nb.HasClient(accountID), "client should be removed after removing last service") + assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0") } func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) { @@ -171,21 +165,21 @@ func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) { assert.NoError(t, err, "removing from non-existent account should not error") } -func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) { +func TestNetBird_RemovePeer_NonExistentServiceIsNoop(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add one domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add one service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - // Remove non-existent domain - should not affect existing domain. - err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test")) + // Remove non-existent service - should not affect existing service. + err = nb.RemovePeer(context.Background(), accountID, "nonexistent.test") require.NoError(t, err) - // Original domain should still be registered. + // Original service should still be registered. assert.True(t, nb.HasClient(accountID)) - assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain") + assert.Equal(t, 1, nb.ServiceCount(accountID), "original service should remain") } func TestWithAccountID_AndAccountIDFromContext(t *testing.T) { @@ -216,19 +210,17 @@ func TestNetBird_StopAll_StopsAllClients(t *testing.T) { account2 := types.AccountID("account-2") account3 := types.AccountID("account-3") - // Add domains for multiple accounts. - err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1") + // Add services for multiple accounts. + err := nb.AddPeer(context.Background(), account1, "domain1.test", "key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2") + err = nb.AddPeer(context.Background(), account2, "domain2.test", "key-2", types.ServiceID("proxy-2")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3") + err = nb.AddPeer(context.Background(), account3, "domain3.test", "key-3", types.ServiceID("proxy-3")) require.NoError(t, err) assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients") // Stop all clients. - // Note: StopAll may return errors since clients never actually connected, - // but the clients should still be removed from the map. _ = nb.StopAll(context.Background()) assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll") @@ -243,18 +235,18 @@ func TestNetBird_ClientCount(t *testing.T) { assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients") // Add clients for different accounts. - err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1") + err := nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1.test", "key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.Equal(t, 1, nb.ClientCount()) - err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2") + err = nb.AddPeer(context.Background(), types.AccountID("account-2"), "domain2.test", "key-2", types.ServiceID("proxy-2")) require.NoError(t, err) assert.Equal(t, 2, nb.ClientCount()) - // Adding domain to existing account should not increase count. - err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b") + // Adding service to existing account should not increase count. + err = nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1b.test", "key-1", types.ServiceID("proxy-1b")) require.NoError(t, err) - assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count") + assert.Equal(t, 2, nb.ClientCount(), "adding service to existing account should not increase client count") } func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) { @@ -293,8 +285,8 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") - // Add first domain — creates a new client entry. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") + // Add first service — creates a new client entry. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1")) require.NoError(t, err) // Manually mark client as started to simulate background startup completing. @@ -302,15 +294,14 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { nb.clients[accountID].started = true nb.clientsMux.Unlock() - // Add second domain — should notify immediately since client is already started. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2") + // Add second service — should notify immediately since client is already started. + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2")) require.NoError(t, err) calls := notifier.calls() require.Len(t, calls, 1) - assert.Equal(t, string(accountID), calls[0].accountID) - assert.Equal(t, "svc-2", calls[0].serviceID) - assert.Equal(t, "domain2.test", calls[0].domain) + assert.Equal(t, accountID, calls[0].accountID) + assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID) assert.True(t, calls[0].connected) } @@ -323,18 +314,18 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) { }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2") + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2")) require.NoError(t, err) - // Remove one domain — client stays, but disconnection notification fires. + // Remove one service — client stays, but disconnection notification fires. err = nb.RemovePeer(context.Background(), accountID, "domain1.test") require.NoError(t, err) assert.True(t, nb.HasClient(accountID)) calls := notifier.calls() require.Len(t, calls, 1) - assert.Equal(t, "domain1.test", calls[0].domain) + assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID) assert.False(t, calls[0].connected) } diff --git a/proxy/internal/tcp/bench_test.go b/proxy/internal/tcp/bench_test.go new file mode 100644 index 000000000..049f8395d --- /dev/null +++ b/proxy/internal/tcp/bench_test.go @@ -0,0 +1,133 @@ +package tcp + +import ( + "bytes" + "crypto/tls" + "io" + "net" + "testing" +) + +// BenchmarkPeekClientHello_TLS measures the overhead of peeking at a real +// TLS ClientHello and extracting the SNI. This is the per-connection cost +// added to every TLS connection on the main listener. +func BenchmarkPeekClientHello_TLS(b *testing.B) { + // Pre-generate a ClientHello by capturing what crypto/tls sends. + clientConn, serverConn := net.Pipe() + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + var hello []byte + buf := make([]byte, 16384) + n, _ := serverConn.Read(buf) + hello = make([]byte, n) + copy(hello, buf[:n]) + clientConn.Close() + serverConn.Close() + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(hello) + conn := &readerConn{Reader: r} + sni, wrapped, err := PeekClientHello(conn) + if err != nil { + b.Fatal(err) + } + if sni != "app.example.com" { + b.Fatalf("unexpected SNI: %q", sni) + } + // Simulate draining the peeked bytes (what the HTTP server would do). + _, _ = io.Copy(io.Discard, wrapped) + } +} + +// BenchmarkPeekClientHello_NonTLS measures peek overhead for non-TLS +// connections that hit the fast non-handshake exit path. +func BenchmarkPeekClientHello_NonTLS(b *testing.B) { + httpReq := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(httpReq) + conn := &readerConn{Reader: r} + _, wrapped, err := PeekClientHello(conn) + if err != nil { + b.Fatal(err) + } + _, _ = io.Copy(io.Discard, wrapped) + } +} + +// BenchmarkPeekedConn_Read measures the read overhead of the peekedConn +// wrapper compared to a plain connection read. The peeked bytes use +// io.MultiReader which adds one indirection per Read call. +func BenchmarkPeekedConn_Read(b *testing.B) { + data := make([]byte, 4096) + peeked := make([]byte, 512) + buf := make([]byte, 1024) + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(data) + conn := &readerConn{Reader: r} + pc := newPeekedConn(conn, peeked) + for { + _, err := pc.Read(buf) + if err != nil { + break + } + } + } +} + +// BenchmarkExtractSNI measures just the in-memory SNI parsing cost, +// excluding I/O. +func BenchmarkExtractSNI(b *testing.B) { + clientConn, serverConn := net.Pipe() + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + buf := make([]byte, 16384) + n, _ := serverConn.Read(buf) + payload := make([]byte, n-tlsRecordHeaderLen) + copy(payload, buf[tlsRecordHeaderLen:n]) + clientConn.Close() + serverConn.Close() + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + sni := extractSNI(payload) + if sni != "app.example.com" { + b.Fatalf("unexpected SNI: %q", sni) + } + } +} + +// readerConn wraps an io.Reader as a net.Conn for benchmarking. +// Only Read is functional; all other methods are no-ops. +type readerConn struct { + io.Reader + net.Conn +} + +func (c *readerConn) Read(b []byte) (int, error) { + return c.Reader.Read(b) +} diff --git a/proxy/internal/tcp/chanlistener.go b/proxy/internal/tcp/chanlistener.go new file mode 100644 index 000000000..ee64bc0a2 --- /dev/null +++ b/proxy/internal/tcp/chanlistener.go @@ -0,0 +1,76 @@ +package tcp + +import ( + "net" + "sync" +) + +// chanListener implements net.Listener by reading connections from a channel. +// It allows the SNI router to feed HTTP connections to http.Server.ServeTLS. +type chanListener struct { + ch chan net.Conn + addr net.Addr + once sync.Once + closed chan struct{} +} + +func newChanListener(ch chan net.Conn, addr net.Addr) *chanListener { + return &chanListener{ + ch: ch, + addr: addr, + closed: make(chan struct{}), + } +} + +// Accept waits for and returns the next connection from the channel. +func (l *chanListener) Accept() (net.Conn, error) { + for { + select { + case conn, ok := <-l.ch: + if !ok { + return nil, net.ErrClosed + } + return conn, nil + case <-l.closed: + // Drain buffered connections before returning. + for { + select { + case conn, ok := <-l.ch: + if !ok { + return nil, net.ErrClosed + } + _ = conn.Close() + default: + return nil, net.ErrClosed + } + } + } + } +} + +// Close signals the listener to stop accepting connections and drains +// any buffered connections that have not yet been accepted. +func (l *chanListener) Close() error { + l.once.Do(func() { + close(l.closed) + for { + select { + case conn, ok := <-l.ch: + if !ok { + return + } + _ = conn.Close() + default: + return + } + } + }) + return nil +} + +// Addr returns the listener's network address. +func (l *chanListener) Addr() net.Addr { + return l.addr +} + +var _ net.Listener = (*chanListener)(nil) diff --git a/proxy/internal/tcp/peekedconn.go b/proxy/internal/tcp/peekedconn.go new file mode 100644 index 000000000..26f3e5c7c --- /dev/null +++ b/proxy/internal/tcp/peekedconn.go @@ -0,0 +1,39 @@ +package tcp + +import ( + "bytes" + "io" + "net" +) + +// peekedConn wraps a net.Conn and prepends previously peeked bytes +// so that readers see the full original stream transparently. +type peekedConn struct { + net.Conn + reader io.Reader +} + +func newPeekedConn(conn net.Conn, peeked []byte) *peekedConn { + return &peekedConn{ + Conn: conn, + reader: io.MultiReader(bytes.NewReader(peeked), conn), + } +} + +// Read replays the peeked bytes first, then reads from the underlying conn. +func (c *peekedConn) Read(b []byte) (int, error) { + return c.reader.Read(b) +} + +// CloseWrite delegates to the underlying connection if it supports +// half-close (e.g. *net.TCPConn). Without this, embedding net.Conn +// as an interface hides the concrete type's CloseWrite method, making +// half-close a silent no-op for all SNI-routed connections. +func (c *peekedConn) CloseWrite() error { + if hc, ok := c.Conn.(halfCloser); ok { + return hc.CloseWrite() + } + return nil +} + +var _ halfCloser = (*peekedConn)(nil) diff --git a/proxy/internal/tcp/proxyprotocol.go b/proxy/internal/tcp/proxyprotocol.go new file mode 100644 index 000000000..699b75a5d --- /dev/null +++ b/proxy/internal/tcp/proxyprotocol.go @@ -0,0 +1,29 @@ +package tcp + +import ( + "fmt" + "net" + + "github.com/pires/go-proxyproto" +) + +// writeProxyProtoV2 sends a PROXY protocol v2 header to the backend connection, +// conveying the real client address. +func writeProxyProtoV2(client, backend net.Conn) error { + tp := proxyproto.TCPv4 + if addr, ok := client.RemoteAddr().(*net.TCPAddr); ok && addr.IP.To4() == nil { + tp = proxyproto.TCPv6 + } + + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: tp, + SourceAddr: client.RemoteAddr(), + DestinationAddr: client.LocalAddr(), + } + if _, err := header.WriteTo(backend); err != nil { + return fmt.Errorf("write PROXY protocol v2 header: %w", err) + } + return nil +} diff --git a/proxy/internal/tcp/proxyprotocol_test.go b/proxy/internal/tcp/proxyprotocol_test.go new file mode 100644 index 000000000..f8c48b2ab --- /dev/null +++ b/proxy/internal/tcp/proxyprotocol_test.go @@ -0,0 +1,128 @@ +package tcp + +import ( + "bufio" + "net" + "testing" + + "github.com/pires/go-proxyproto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWriteProxyProtoV2_IPv4(t *testing.T) { + // Set up a real TCP listener and dial to get connections with real addresses. + ln, err := net.Listen("tcp4", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + var serverConn net.Conn + accepted := make(chan struct{}) + go func() { + var err error + serverConn, err = ln.Accept() + if err != nil { + t.Error("accept failed:", err) + } + close(accepted) + }() + + clientConn, err := net.Dial("tcp4", ln.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + <-accepted + defer serverConn.Close() + + // Use a pipe as the backend: write the header to one end, read from the other. + backendRead, backendWrite := net.Pipe() + defer backendRead.Close() + defer backendWrite.Close() + + // serverConn is the "client" arg: RemoteAddr is the source, LocalAddr is the destination. + writeDone := make(chan error, 1) + go func() { + writeDone <- writeProxyProtoV2(serverConn, backendWrite) + }() + + // Read the PROXY protocol header from the backend read side. + header, err := proxyproto.Read(bufio.NewReader(backendRead)) + require.NoError(t, err) + require.NotNil(t, header, "should have received a proxy protocol header") + + writeErr := <-writeDone + require.NoError(t, writeErr) + + assert.Equal(t, byte(2), header.Version, "version should be 2") + assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY") + assert.Equal(t, proxyproto.TCPv4, header.TransportProtocol, "transport should be TCPv4") + + // serverConn.RemoteAddr() is the client's address (source in the header). + expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr) + actualSrc := header.SourceAddr.(*net.TCPAddr) + assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr") + assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr") + + // serverConn.LocalAddr() is the server's address (destination in the header). + expectedDst := serverConn.LocalAddr().(*net.TCPAddr) + actualDst := header.DestinationAddr.(*net.TCPAddr) + assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr") + assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr") +} + +func TestWriteProxyProtoV2_IPv6(t *testing.T) { + // Set up a real TCP6 listener on loopback. + ln, err := net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Skip("IPv6 not available:", err) + } + defer ln.Close() + + var serverConn net.Conn + accepted := make(chan struct{}) + go func() { + var err error + serverConn, err = ln.Accept() + if err != nil { + t.Error("accept failed:", err) + } + close(accepted) + }() + + clientConn, err := net.Dial("tcp6", ln.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + <-accepted + defer serverConn.Close() + + backendRead, backendWrite := net.Pipe() + defer backendRead.Close() + defer backendWrite.Close() + + writeDone := make(chan error, 1) + go func() { + writeDone <- writeProxyProtoV2(serverConn, backendWrite) + }() + + header, err := proxyproto.Read(bufio.NewReader(backendRead)) + require.NoError(t, err) + require.NotNil(t, header, "should have received a proxy protocol header") + + writeErr := <-writeDone + require.NoError(t, writeErr) + + assert.Equal(t, byte(2), header.Version, "version should be 2") + assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY") + assert.Equal(t, proxyproto.TCPv6, header.TransportProtocol, "transport should be TCPv6") + + expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr) + actualSrc := header.SourceAddr.(*net.TCPAddr) + assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr") + assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr") + + expectedDst := serverConn.LocalAddr().(*net.TCPAddr) + actualDst := header.DestinationAddr.(*net.TCPAddr) + assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr") + assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr") +} diff --git a/proxy/internal/tcp/relay.go b/proxy/internal/tcp/relay.go new file mode 100644 index 000000000..39949818d --- /dev/null +++ b/proxy/internal/tcp/relay.go @@ -0,0 +1,156 @@ +package tcp + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/netutil" +) + +// errIdleTimeout is returned when a relay connection is closed due to inactivity. +var errIdleTimeout = errors.New("idle timeout") + +// DefaultIdleTimeout is the default idle timeout for TCP relay connections. +// A zero value disables idle timeout checking. +const DefaultIdleTimeout = 5 * time.Minute + +// halfCloser is implemented by connections that support half-close +// (e.g. *net.TCPConn). When one copy direction finishes, we signal +// EOF to the remote by closing the write side while keeping the read +// side open so the other direction can drain. +type halfCloser interface { + CloseWrite() error +} + +// copyBufPool avoids allocating a new 32KB buffer per io.Copy call. +var copyBufPool = sync.Pool{ + New: func() any { + buf := make([]byte, 32*1024) + return &buf + }, +} + +// Relay copies data bidirectionally between src and dst until both +// sides are done or the context is canceled. When idleTimeout is +// non-zero, each direction's read is deadline-guarded; if no data +// flows within the timeout the connection is torn down. When one +// direction finishes, it half-closes the write side of the +// destination (if supported) to signal EOF, allowing the other +// direction to drain gracefully before the full connection teardown. +func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-ctx.Done() + _ = src.Close() + _ = dst.Close() + }() + + var wg sync.WaitGroup + wg.Add(2) + + var errSrcToDst, errDstToSrc error + + go func() { + defer wg.Done() + srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout) + halfClose(dst) + cancel() + }() + + go func() { + defer wg.Done() + dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout) + halfClose(src) + cancel() + }() + + wg.Wait() + + if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) { + logger.Debug("relay closed due to idle timeout") + } + if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) { + logger.Debugf("relay copy error (src→dst): %v", errSrcToDst) + } + if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) { + logger.Debugf("relay copy error (dst→src): %v", errDstToSrc) + } + + return srcToDst, dstToSrc +} + +// copyWithIdleTimeout copies from src to dst using a pooled buffer. +// When idleTimeout > 0 it sets a read deadline on src before each +// read and treats a timeout as an idle-triggered close. +func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) { + bufp := copyBufPool.Get().(*[]byte) + defer copyBufPool.Put(bufp) + + if idleTimeout <= 0 { + return io.CopyBuffer(dst, src, *bufp) + } + + conn, ok := src.(net.Conn) + if !ok { + return io.CopyBuffer(dst, src, *bufp) + } + + buf := *bufp + var total int64 + for { + if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil { + return total, err + } + nr, readErr := src.Read(buf) + if nr > 0 { + n, err := checkedWrite(dst, buf[:nr]) + total += n + if err != nil { + return total, err + } + } + if readErr != nil { + if netutil.IsTimeout(readErr) { + return total, errIdleTimeout + } + return total, readErr + } + } +} + +// checkedWrite writes buf to dst and returns the number of bytes written. +// It guards against short writes and negative counts per io.Copy convention. +func checkedWrite(dst io.Writer, buf []byte) (int64, error) { + nw, err := dst.Write(buf) + if nw < 0 || nw > len(buf) { + nw = 0 + } + if err != nil { + return int64(nw), err + } + if nw != len(buf) { + return int64(nw), io.ErrShortWrite + } + return int64(nw), nil +} + +func isExpectedCopyError(err error) bool { + return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err) +} + +// halfClose attempts to half-close the write side of the connection. +// If the connection does not support half-close, this is a no-op. +func halfClose(conn net.Conn) { + if hc, ok := conn.(halfCloser); ok { + // Best-effort; the full close will follow shortly. + _ = hc.CloseWrite() + } +} diff --git a/proxy/internal/tcp/relay_test.go b/proxy/internal/tcp/relay_test.go new file mode 100644 index 000000000..e42d65b9d --- /dev/null +++ b/proxy/internal/tcp/relay_test.go @@ -0,0 +1,210 @@ +package tcp + +import ( + "context" + "fmt" + "io" + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/netutil" +) + +func TestRelay_BidirectionalCopy(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + srcData := []byte("hello from src") + dstData := []byte("hello from dst") + + // dst side: write response first, then read + close. + go func() { + _, _ = dstClient.Write(dstData) + buf := make([]byte, 256) + _, _ = dstClient.Read(buf) + dstClient.Close() + }() + + // src side: read the response, then send data + close. + go func() { + buf := make([]byte, 256) + _, _ = srcClient.Read(buf) + _, _ = srcClient.Write(srcData) + srcClient.Close() + }() + + s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0) + + assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst") + assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src") +} + +func TestRelay_ContextCancellation(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + defer srcClient.Close() + defer dstClient.Close() + + logger := log.NewEntry(log.StandardLogger()) + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + Relay(ctx, logger, srcServer, dstServer, 0) + close(done) + }() + + // Cancel should cause Relay to return. + cancel() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Relay did not return after context cancellation") + } +} + +func TestRelay_OneSideClosed(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + defer dstClient.Close() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // Close src immediately. Relay should complete without hanging. + srcClient.Close() + + done := make(chan struct{}) + go func() { + Relay(ctx, logger, srcServer, dstServer, 0) + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Relay did not return after one side closed") + } +} + +func TestRelay_LargeTransfer(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // 1MB of data. + data := make([]byte, 1<<20) + for i := range data { + data[i] = byte(i % 256) + } + + go func() { + _, _ = srcClient.Write(data) + srcClient.Close() + }() + + errCh := make(chan error, 1) + go func() { + received, err := io.ReadAll(dstClient) + if err != nil { + errCh <- err + return + } + if len(received) != len(data) { + errCh <- fmt.Errorf("expected %d bytes, got %d", len(data), len(received)) + return + } + errCh <- nil + dstClient.Close() + }() + + s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0) + assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes") + require.NoError(t, <-errCh) +} + +func TestRelay_IdleTimeout(t *testing.T) { + // Use real TCP connections so SetReadDeadline works (net.Pipe + // does not support deadlines). + srcLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer srcLn.Close() + + dstLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer dstLn.Close() + + srcClient, err := net.Dial("tcp", srcLn.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer srcClient.Close() + + srcServer, err := srcLn.Accept() + if err != nil { + t.Fatal(err) + } + + dstClient, err := net.Dial("tcp", dstLn.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer dstClient.Close() + + dstServer, err := dstLn.Accept() + if err != nil { + t.Fatal(err) + } + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // Send initial data to prove the relay works. + go func() { + _, _ = srcClient.Write([]byte("ping")) + }() + + done := make(chan struct{}) + var s2d, d2s int64 + go func() { + s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond) + close(done) + }() + + // Read the forwarded data on the dst side. + buf := make([]byte, 64) + n, err := dstClient.Read(buf) + assert.NoError(t, err) + assert.Equal(t, "ping", string(buf[:n])) + + // Now stop sending. The relay should close after the idle timeout. + select { + case <-done: + assert.Greater(t, s2d, int64(0), "should have transferred initial data") + _ = d2s + case <-time.After(5 * time.Second): + t.Fatal("Relay did not exit after idle timeout") + } +} + +func TestIsExpectedError(t *testing.T) { + assert.True(t, netutil.IsExpectedError(net.ErrClosed)) + assert.True(t, netutil.IsExpectedError(context.Canceled)) + assert.True(t, netutil.IsExpectedError(io.EOF)) + assert.False(t, netutil.IsExpectedError(io.ErrUnexpectedEOF)) +} diff --git a/proxy/internal/tcp/router.go b/proxy/internal/tcp/router.go new file mode 100644 index 000000000..84fde0731 --- /dev/null +++ b/proxy/internal/tcp/router.go @@ -0,0 +1,570 @@ +package tcp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "slices" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/accesslog" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +// defaultDialTimeout is the fallback dial timeout when no per-route +// timeout is configured. +const defaultDialTimeout = 30 * time.Second + +// SNIHost is a typed key for SNI hostname lookups. +type SNIHost string + +// RouteType specifies how a connection should be handled. +type RouteType int + +const ( + // RouteHTTP routes the connection through the HTTP reverse proxy. + RouteHTTP RouteType = iota + // RouteTCP relays the connection directly to the backend (TLS passthrough). + RouteTCP +) + +const ( + // sniPeekTimeout is the deadline for reading the TLS ClientHello. + sniPeekTimeout = 5 * time.Second + // DefaultDrainTimeout is the default grace period for in-flight relay + // connections to finish during shutdown. + DefaultDrainTimeout = 30 * time.Second + // DefaultMaxRelayConns is the default cap on concurrent TCP relay connections per router. + DefaultMaxRelayConns = 4096 + // httpChannelBuffer is the capacity of the channel feeding HTTP connections. + httpChannelBuffer = 4096 +) + +// DialResolver returns a DialContextFunc for the given account. +type DialResolver func(accountID types.AccountID) (types.DialContextFunc, error) + +// Route describes where a connection for a given SNI should be sent. +type Route struct { + Type RouteType + AccountID types.AccountID + ServiceID types.ServiceID + // Domain is the service's configured domain, used for access log entries. + Domain string + // Protocol is the frontend protocol (tcp, tls), used for access log entries. + Protocol accesslog.Protocol + // Target is the backend address for TCP relay (e.g. "10.0.0.5:5432"). + Target string + // ProxyProtocol enables sending a PROXY protocol v2 header to the backend. + ProxyProtocol bool + // DialTimeout overrides the default dial timeout for this route. + // Zero uses defaultDialTimeout. + DialTimeout time.Duration +} + +// l4Logger sends layer-4 access log entries to the management server. +type l4Logger interface { + LogL4(entry accesslog.L4Entry) +} + +// RelayObserver receives callbacks for TCP relay lifecycle events. +// All methods must be safe for concurrent use. +type RelayObserver interface { + TCPRelayStarted(accountID types.AccountID) + TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64) + TCPRelayDialError(accountID types.AccountID) + TCPRelayRejected(accountID types.AccountID) +} + +// Router accepts raw TCP connections on a shared listener, peeks at +// the TLS ClientHello to extract the SNI, and routes the connection +// to either the HTTP reverse proxy or a direct TCP relay. +type Router struct { + logger *log.Logger + // httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter. + httpCh chan net.Conn + httpListener *chanListener + mu sync.RWMutex + routes map[SNIHost][]Route + fallback *Route + draining bool + dialResolve DialResolver + activeConns sync.WaitGroup + activeRelays sync.WaitGroup + relaySem chan struct{} + drainDone chan struct{} + observer RelayObserver + accessLog l4Logger + // svcCtxs tracks a context per service ID. All relay goroutines for a + // service derive from its context; canceling it kills them immediately. + svcCtxs map[types.ServiceID]context.Context + svcCancels map[types.ServiceID]context.CancelFunc +} + +// NewRouter creates a new SNI-based connection router. +func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router { + httpCh := make(chan net.Conn, httpChannelBuffer) + return &Router{ + logger: logger, + httpCh: httpCh, + httpListener: newChanListener(httpCh, addr), + routes: make(map[SNIHost][]Route), + dialResolve: dialResolve, + relaySem: make(chan struct{}, DefaultMaxRelayConns), + svcCtxs: make(map[types.ServiceID]context.Context), + svcCancels: make(map[types.ServiceID]context.CancelFunc), + } +} + +// NewPortRouter creates a Router for a dedicated port without an HTTP +// channel. Connections that don't match any SNI route fall through to +// the fallback relay (if set) or are closed. +func NewPortRouter(logger *log.Logger, dialResolve DialResolver) *Router { + return &Router{ + logger: logger, + routes: make(map[SNIHost][]Route), + dialResolve: dialResolve, + relaySem: make(chan struct{}, DefaultMaxRelayConns), + svcCtxs: make(map[types.ServiceID]context.Context), + svcCancels: make(map[types.ServiceID]context.CancelFunc), + } +} + +// HTTPListener returns a net.Listener that yields connections routed +// to the HTTP handler. Use this with http.Server.ServeTLS. +func (r *Router) HTTPListener() net.Listener { + return r.httpListener +} + +// AddRoute registers an SNI route. Multiple routes for the same host are +// stored and resolved by priority at lookup time (HTTP > TCP). +// Empty host is ignored to prevent conflicts with ECH/ESNI fallback. +func (r *Router) AddRoute(host SNIHost, route Route) { + if host == "" { + return + } + + r.mu.Lock() + defer r.mu.Unlock() + + routes := r.routes[host] + for i, existing := range routes { + if existing.ServiceID == route.ServiceID { + r.cancelServiceLocked(route.ServiceID) + routes[i] = route + return + } + } + r.routes[host] = append(routes, route) +} + +// RemoveRoute removes the route for the given host and service ID. +// Active relay connections for the service are closed immediately. +// If other routes remain for the host, they are preserved. +func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) { + r.mu.Lock() + defer r.mu.Unlock() + + r.routes[host] = slices.DeleteFunc(r.routes[host], func(route Route) bool { + return route.ServiceID == svcID + }) + if len(r.routes[host]) == 0 { + delete(r.routes, host) + } + r.cancelServiceLocked(svcID) +} + +// SetFallback registers a catch-all route for connections that don't +// match any SNI route. On a port router this handles plain TCP relay; +// on the main router it takes priority over the HTTP channel. +func (r *Router) SetFallback(route Route) { + r.mu.Lock() + defer r.mu.Unlock() + r.fallback = &route +} + +// RemoveFallback clears the catch-all fallback route and closes any +// active relay connections for the given service. +func (r *Router) RemoveFallback(svcID types.ServiceID) { + r.mu.Lock() + defer r.mu.Unlock() + r.fallback = nil + r.cancelServiceLocked(svcID) +} + +// SetObserver sets the relay lifecycle observer. Must be called before Serve. +func (r *Router) SetObserver(obs RelayObserver) { + r.mu.Lock() + defer r.mu.Unlock() + r.observer = obs +} + +// SetAccessLogger sets the L4 access logger. Must be called before Serve. +func (r *Router) SetAccessLogger(l l4Logger) { + r.mu.Lock() + defer r.mu.Unlock() + r.accessLog = l +} + +// getObserver returns the current relay observer under the read lock. +func (r *Router) getObserver() RelayObserver { + r.mu.RLock() + defer r.mu.RUnlock() + return r.observer +} + +// IsEmpty returns true when the router has no SNI routes and no fallback. +func (r *Router) IsEmpty() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.routes) == 0 && r.fallback == nil +} + +// Serve accepts connections from ln and routes them based on SNI. +// It blocks until ctx is canceled or ln is closed, then drains +// active relay connections up to DefaultDrainTimeout. +func (r *Router) Serve(ctx context.Context, ln net.Listener) error { + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + _ = ln.Close() + if r.httpListener != nil { + r.httpListener.Close() + } + case <-done: + } + }() + + for { + conn, err := ln.Accept() + if err != nil { + if ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + if ok := r.Drain(DefaultDrainTimeout); !ok { + r.logger.Warn("timed out waiting for connections to drain") + } + return nil + } + r.logger.Debugf("SNI router accept: %v", err) + continue + } + r.activeConns.Add(1) + go func() { + defer r.activeConns.Done() + r.handleConn(ctx, conn) + }() + } +} + +// handleConn peeks at the TLS ClientHello and routes the connection. +func (r *Router) handleConn(ctx context.Context, conn net.Conn) { + // Fast path: when no SNI routes and no HTTP channel exist (pure TCP + // fallback port), skip the TLS peek entirely to avoid read errors on + // non-TLS connections and reduce latency. + if r.isFallbackOnly() { + r.handleUnmatched(ctx, conn) + return + } + + if err := conn.SetReadDeadline(time.Now().Add(sniPeekTimeout)); err != nil { + r.logger.Debugf("set SNI peek deadline: %v", err) + _ = conn.Close() + return + } + + sni, wrapped, err := PeekClientHello(conn) + if err != nil { + r.logger.Debugf("SNI peek: %v", err) + if wrapped != nil { + r.handleUnmatched(ctx, wrapped) + } else { + _ = conn.Close() + } + return + } + + if err := wrapped.SetReadDeadline(time.Time{}); err != nil { + r.logger.Debugf("clear SNI peek deadline: %v", err) + _ = wrapped.Close() + return + } + + host := SNIHost(sni) + route, ok := r.lookupRoute(host) + if !ok { + r.handleUnmatched(ctx, wrapped) + return + } + + if route.Type == RouteHTTP { + r.sendToHTTP(wrapped) + return + } + + if err := r.relayTCP(ctx, wrapped, host, route); err != nil { + r.logger.WithFields(log.Fields{ + "sni": host, + "service_id": route.ServiceID, + "target": route.Target, + }).Warnf("TCP relay: %v", err) + _ = wrapped.Close() + } +} + +// isFallbackOnly returns true when the router has no SNI routes and no HTTP +// channel, meaning all connections should go directly to the fallback relay. +func (r *Router) isFallbackOnly() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.routes) == 0 && r.httpCh == nil +} + +// handleUnmatched routes a connection that didn't match any SNI route. +// This includes ECH/ESNI connections where the cleartext SNI is empty. +// It tries the fallback relay first, then the HTTP channel, and closes +// the connection if neither is available. +func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) { + r.mu.RLock() + fb := r.fallback + r.mu.RUnlock() + + if fb != nil { + if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil { + r.logger.WithFields(log.Fields{ + "service_id": fb.ServiceID, + "target": fb.Target, + }).Warnf("TCP relay (fallback): %v", err) + _ = conn.Close() + } + return + } + r.sendToHTTP(conn) +} + +// lookupRoute returns the highest-priority route for the given SNI host. +// HTTP routes take precedence over TCP routes. +func (r *Router) lookupRoute(host SNIHost) (Route, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + routes, ok := r.routes[host] + if !ok || len(routes) == 0 { + return Route{}, false + } + best := routes[0] + for _, route := range routes[1:] { + if route.Type < best.Type { + best = route + } + } + return best, true +} + +// sendToHTTP feeds the connection to the HTTP handler via the channel. +// If no HTTP channel is configured (port router), the router is +// draining, or the channel is full, the connection is closed. +func (r *Router) sendToHTTP(conn net.Conn) { + if r.httpCh == nil { + _ = conn.Close() + return + } + + r.mu.RLock() + draining := r.draining + r.mu.RUnlock() + + if draining { + _ = conn.Close() + return + } + + select { + case r.httpCh <- conn: + default: + r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr()) + _ = conn.Close() + } +} + +// Drain prevents new relay connections from starting and waits for all +// in-flight connection handlers and active relays to finish, up to the +// given timeout. Returns true if all completed, false on timeout. +func (r *Router) Drain(timeout time.Duration) bool { + r.mu.Lock() + r.draining = true + if r.drainDone == nil { + done := make(chan struct{}) + go func() { + r.activeConns.Wait() + r.activeRelays.Wait() + close(done) + }() + r.drainDone = done + } + done := r.drainDone + r.mu.Unlock() + + select { + case <-done: + return true + case <-time.After(timeout): + return false + } +} + +// cancelServiceLocked cancels and removes the context for the given service, +// closing all its active relay connections. Must be called with mu held. +func (r *Router) cancelServiceLocked(svcID types.ServiceID) { + if cancel, ok := r.svcCancels[svcID]; ok { + cancel() + delete(r.svcCtxs, svcID) + delete(r.svcCancels, svcID) + } +} + +// relayTCP sets up and runs a bidirectional TCP relay. +// The caller owns conn and must close it if this method returns an error. +// On success (nil error), both conn and backend are closed by the relay. +func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error { + svcCtx, err := r.acquireRelay(ctx, route) + if err != nil { + return err + } + defer func() { + <-r.relaySem + r.activeRelays.Done() + }() + + backend, err := r.dialBackend(svcCtx, route) + if err != nil { + obs := r.getObserver() + if obs != nil { + obs.TCPRelayDialError(route.AccountID) + } + return err + } + + if route.ProxyProtocol { + if err := writeProxyProtoV2(conn, backend); err != nil { + _ = backend.Close() + return fmt.Errorf("write PROXY protocol header: %w", err) + } + } + + obs := r.getObserver() + if obs != nil { + obs.TCPRelayStarted(route.AccountID) + } + + entry := r.logger.WithFields(log.Fields{ + "sni": sni, + "service_id": route.ServiceID, + "target": route.Target, + }) + entry.Debug("TCP relay started") + + start := time.Now() + s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout) + elapsed := time.Since(start) + + if obs != nil { + obs.TCPRelayEnded(route.AccountID, elapsed, s2d, d2s) + } + entry.Debugf("TCP relay ended (client→backend: %d bytes, backend→client: %d bytes)", s2d, d2s) + + r.logL4Entry(route, conn, elapsed, s2d, d2s) + return nil +} + +// acquireRelay checks draining state, increments activeRelays, and acquires +// a semaphore slot. Returns the per-service context on success. +// The caller must release the semaphore and call activeRelays.Done() when done. +func (r *Router) acquireRelay(ctx context.Context, route Route) (context.Context, error) { + r.mu.Lock() + if r.draining { + r.mu.Unlock() + return nil, errors.New("router is draining") + } + r.activeRelays.Add(1) + svcCtx := r.getOrCreateServiceCtxLocked(ctx, route.ServiceID) + r.mu.Unlock() + + select { + case r.relaySem <- struct{}{}: + return svcCtx, nil + default: + r.activeRelays.Done() + obs := r.getObserver() + if obs != nil { + obs.TCPRelayRejected(route.AccountID) + } + return nil, errors.New("TCP relay connection limit reached") + } +} + +// dialBackend resolves the dialer for the route's account and dials the backend. +func (r *Router) dialBackend(svcCtx context.Context, route Route) (net.Conn, error) { + dialFn, err := r.dialResolve(route.AccountID) + if err != nil { + return nil, fmt.Errorf("resolve dialer: %w", err) + } + + dialTimeout := route.DialTimeout + if dialTimeout <= 0 { + dialTimeout = defaultDialTimeout + } + dialCtx, dialCancel := context.WithTimeout(svcCtx, dialTimeout) + backend, err := dialFn(dialCtx, "tcp", route.Target) + dialCancel() + if err != nil { + return nil, fmt.Errorf("dial backend %s: %w", route.Target, err) + } + return backend, nil +} + +// logL4Entry sends a TCP relay access log entry if an access logger is configured. +func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, bytesUp, bytesDown int64) { + r.mu.RLock() + al := r.accessLog + r.mu.RUnlock() + + if al == nil { + return + } + + var sourceIP netip.Addr + if remote := conn.RemoteAddr(); remote != nil { + if ap, err := netip.ParseAddrPort(remote.String()); err == nil { + sourceIP = ap.Addr().Unmap() + } + } + + al.LogL4(accesslog.L4Entry{ + AccountID: route.AccountID, + ServiceID: route.ServiceID, + Protocol: route.Protocol, + Host: route.Domain, + SourceIP: sourceIP, + DurationMs: duration.Milliseconds(), + BytesUpload: bytesUp, + BytesDownload: bytesDown, + }) +} + +// getOrCreateServiceCtxLocked returns the context for a service, creating one +// if it doesn't exist yet. The context is a child of the server context. +// Must be called with mu held. +func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types.ServiceID) context.Context { + if ctx, ok := r.svcCtxs[svcID]; ok { + return ctx + } + ctx, cancel := context.WithCancel(parent) + r.svcCtxs[svcID] = ctx + r.svcCancels[svcID] = cancel + return ctx +} diff --git a/proxy/internal/tcp/router_test.go b/proxy/internal/tcp/router_test.go new file mode 100644 index 000000000..0e2cfe3e1 --- /dev/null +++ b/proxy/internal/tcp/router_test.go @@ -0,0 +1,1670 @@ +package tcp + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "math/big" + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestRouter_HTTPRouting(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + router := NewRouter(logger, nil, addr) + router.AddRoute("example.com", Route{Type: RouteHTTP}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Dial in a goroutine. The TLS handshake will block since nothing + // completes it on the HTTP side, but we only care about routing. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + // Send a TLS ClientHello manually. + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: "example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + tlsConn.Close() + }() + + // Verify the connection was routed to the HTTP channel. + select { + case conn := <-router.httpCh: + assert.NotNil(t, conn) + conn.Close() + case <-time.After(5 * time.Second): + t.Fatal("no connection received on HTTP channel") + } +} + +func TestRouter_TCPRouting(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + // Set up a TLS backend that the relay will connect to. + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + backendAddr := backendLn.Addr().String() + + // Accept one connection on the backend, echo data back. + backendReady := make(chan struct{}) + go func() { + close(backendReady) + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + <-backendReady + + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendAddr, + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Connect as a TLS client; the proxy should passthrough to the backend. + clientConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer clientConn.Close() + + testData := []byte("hello through TCP passthrough") + _, err = clientConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := clientConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed data through TCP passthrough") +} + +func TestRouter_UnknownSNIGoesToHTTP(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + router := NewRouter(logger, nil, addr) + // No routes registered. + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: "unknown.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + tlsConn.Close() + }() + + select { + case conn := <-router.httpCh: + assert.NotNil(t, conn) + conn.Close() + case <-time.After(5 * time.Second): + t.Fatal("unknown SNI should be routed to HTTP") + } +} + +// TestRouter_NonTLSConnectionDropped verifies that a non-TLS connection +// on the shared port is closed by the router (SNI peek fails to find a +// valid ClientHello, so there is no route match). +func TestRouter_NonTLSConnectionDropped(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + // Register a TLS passthrough route. Non-TLS should NOT match. + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: "127.0.0.1:9999", + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Send plain HTTP (non-TLS) data. + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: tcp.example.com\r\n\r\n")) + + // Non-TLS traffic on a port with RouteTCP goes to the HTTP channel + // because there's no valid SNI to match. Verify it reaches HTTP. + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "non-TLS connection should fall through to HTTP") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("non-TLS connection was not routed to HTTP") + } +} + +// TestRouter_TLSAndHTTPCoexist verifies that a shared port with both HTTP +// and TLS passthrough routes correctly demuxes based on the SNI hostname. +func TestRouter_TLSAndHTTPCoexist(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + // Backend echoes data. + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + // HTTP route. + router.AddRoute("app.example.com", Route{Type: RouteHTTP}) + // TLS passthrough route. + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // 1. TLS connection with SNI "tcp.example.com" → TLS passthrough. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + testData := []byte("passthrough data") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "TLS passthrough should relay data") + tlsConn.Close() + + // 2. TLS connection with SNI "app.example.com" → HTTP handler. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + c := tls.Client(conn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = c.Handshake() + c.Close() + }() + + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "HTTP SNI should go to HTTP handler") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("HTTP-route connection was not delivered to HTTP handler") + } +} + +func TestRouter_AddRemoveRoute(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, nil, addr) + + router.AddRoute("a.example.com", Route{Type: RouteHTTP, ServiceID: "svc-a"}) + router.AddRoute("b.example.com", Route{Type: RouteTCP, ServiceID: "svc-b", Target: "10.0.0.1:5432"}) + + route, ok := router.lookupRoute("a.example.com") + assert.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type) + + route, ok = router.lookupRoute("b.example.com") + assert.True(t, ok) + assert.Equal(t, RouteTCP, route.Type) + + router.RemoveRoute("a.example.com", "svc-a") + _, ok = router.lookupRoute("a.example.com") + assert.False(t, ok) +} + +func TestChanListener_AcceptAndClose(t *testing.T) { + ch := make(chan net.Conn, 1) + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + ln := newChanListener(ch, addr) + + assert.Equal(t, addr, ln.Addr()) + + // Send a connection. + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + ch <- serverConn + + conn, err := ln.Accept() + require.NoError(t, err) + assert.Equal(t, serverConn, conn) + + // Close should cause Accept to return error. + require.NoError(t, ln.Close()) + // Double close should be safe. + require.NoError(t, ln.Close()) + + _, err = ln.Accept() + assert.ErrorIs(t, err, net.ErrClosed) +} + +func TestRouter_HTTPPrecedenceGuard(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, nil, addr) + + host := SNIHost("app.example.com") + + t.Run("http takes precedence over tcp at lookup", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type, "HTTP route must take precedence over TCP") + assert.Equal(t, types.ServiceID("svc-http"), route.ServiceID) + + router.RemoveRoute(host, "svc-http") + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("tcp becomes active when http is removed", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + + router.RemoveRoute(host, "svc-http") + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteTCP, route.Type, "TCP should take over after HTTP removal") + assert.Equal(t, types.ServiceID("svc-tcp"), route.ServiceID) + + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("order of add does not matter", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type, "HTTP takes precedence regardless of add order") + + router.RemoveRoute(host, "svc-http") + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("same service id updates in place", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-1", Target: "10.0.0.1:443"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-1", Target: "10.0.0.2:443"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, "10.0.0.2:443", route.Target, "route should be updated in place") + + router.RemoveRoute(host, "svc-1") + _, ok = router.lookupRoute(host) + assert.False(t, ok) + }) + + t.Run("double remove is safe", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-1"}) + router.RemoveRoute(host, "svc-1") + router.RemoveRoute(host, "svc-1") + + _, ok := router.lookupRoute(host) + assert.False(t, ok, "route should be gone after removal") + }) + + t.Run("remove does not affect other hosts", func(t *testing.T) { + router.AddRoute("a.example.com", Route{Type: RouteHTTP, ServiceID: "svc-a"}) + router.AddRoute("b.example.com", Route{Type: RouteTCP, ServiceID: "svc-b", Target: "10.0.0.2:22"}) + + router.RemoveRoute("a.example.com", "svc-a") + + _, ok := router.lookupRoute(SNIHost("a.example.com")) + assert.False(t, ok) + + route, ok := router.lookupRoute(SNIHost("b.example.com")) + require.True(t, ok) + assert.Equal(t, RouteTCP, route.Type, "removing one host must not affect another") + + router.RemoveRoute("b.example.com", "svc-b") + }) +} + +func TestRouter_SetRemoveFallback(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + assert.True(t, router.IsEmpty(), "new port router should be empty") + + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb", Target: "10.0.0.1:5432"}) + assert.False(t, router.IsEmpty(), "router with fallback should not be empty") + + router.AddRoute("a.example.com", Route{Type: RouteTCP, ServiceID: "svc-a", Target: "10.0.0.2:443"}) + assert.False(t, router.IsEmpty()) + + router.RemoveFallback("svc-fb") + assert.False(t, router.IsEmpty(), "router with SNI route should not be empty") + + router.RemoveRoute("a.example.com", "svc-a") + assert.True(t, router.IsEmpty(), "router with no routes and no fallback should be empty") +} + +func TestPortRouter_FallbackRelaysData(t *testing.T) { + // Backend echo server. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Plain TCP (non-TLS) connection should be relayed via fallback. + // Use exactly 5 bytes. PeekClientHello reads 5 bytes as the TLS + // header, so a single 5-byte write lands as one chunk at the backend. + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + testData := []byte("hello") + _, err = conn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed data through fallback relay") +} + +func TestPortRouter_FallbackOnUnknownSNI(t *testing.T) { + // Backend TLS echo server. + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + // Only a fallback, no SNI route for "unknown.example.com". + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // TLS with unknown SNI → fallback relay to TLS backend. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer tlsConn.Close() + + testData := []byte("hello through fallback TLS") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "unknown SNI should relay through fallback") +} + +func TestPortRouter_SNIWinsOverFallback(t *testing.T) { + // Two backend echo servers: one for SNI match, one for fallback. + sniBacked := startEchoTLS(t) + fbBacked := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "sni-service", + Target: sniBacked.Addr().String(), + }) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "fb-service", + Target: fbBacked.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // TLS with matching SNI should go to SNI backend, not fallback. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer tlsConn.Close() + + testData := []byte("SNI route data") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "SNI match should use SNI route, not fallback") +} + +func TestPortRouter_NoFallbackNoHTTP_Closes(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, _ = conn.Write([]byte("hello")) + + // Connection should be closed by the router (no fallback, no HTTP). + buf := make([]byte, 1) + _ = conn.SetReadDeadline(time.Now().Add(3 * time.Second)) + _, err = conn.Read(buf) + assert.Error(t, err, "connection should be closed when no fallback and no HTTP channel") +} + +func TestRouter_FallbackAndHTTPCoexist(t *testing.T) { + // Fallback backend echo server (plain TCP). + fbBackend, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer fbBackend.Close() + + go func() { + conn, err := fbBackend.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, dialResolve, addr) + + // HTTP route for known SNI. + router.AddRoute("app.example.com", Route{Type: RouteHTTP}) + // Fallback for non-TLS / unknown SNI. + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "fb-service", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // 1. TLS with known HTTP SNI → should go to HTTP channel. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + c := tls.Client(conn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = c.Handshake() + c.Close() + }() + + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "known HTTP SNI should go to HTTP channel") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("HTTP-route connection was not delivered to HTTP handler") + } + + // 2. Plain TCP (non-TLS) → should go to fallback, not HTTP. + // Use exactly 5 bytes to match PeekClientHello header size. + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + testData := []byte("plain") + _, err = conn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "non-TLS should be relayed via fallback, not HTTP") +} + +// startEchoTLS starts a TLS echo server and returns the listener. +func startEchoTLS(t *testing.T) net.Listener { + t.Helper() + + cert := generateSelfSignedCert(t) + ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + if _, err := conn.Write(buf[:n]); err != nil { + return + } + } + }() + + return ln +} + +func generateSelfSignedCert(t *testing.T) tls.Certificate { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + DNSNames: []string{"tcp.example.com"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + return tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: key, + } +} + +func TestRouter_DrainWaitsForRelays(t *testing.T) { + logger := log.StandardLogger() + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + // Accept connections: echo first message, then hold open until told to close. + closeBackend := make(chan struct{}) + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + <-closeBackend + }(conn) + } + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + serveDone := make(chan struct{}) + go func() { + _ = router.Serve(ctx, ln) + close(serveDone) + }() + + // Open a relay connection (non-TLS, hits fallback). + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + _, _ = conn.Write([]byte("hello")) + + // Wait for the echo to confirm the relay is fully established. + buf := make([]byte, 16) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + _ = conn.SetReadDeadline(time.Time{}) + + // Drain with a short timeout should fail because the relay is still active. + assert.False(t, router.Drain(50*time.Millisecond), "drain should timeout with active relay") + + // Close backend connections so relays finish. + close(closeBackend) + _ = conn.Close() + + // Drain should now complete quickly. + assert.True(t, router.Drain(2*time.Second), "drain should succeed after relays end") + + cancel() + <-serveDone +} + +func TestRouter_DrainEmptyReturnsImmediately(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + start := time.Now() + ok := router.Drain(5 * time.Second) + elapsed := time.Since(start) + + assert.True(t, ok) + assert.Less(t, elapsed, 100*time.Millisecond, "drain with no relays should return immediately") +} + +// TestRemoveRoute_KillsActiveRelays verifies that removing a route +// immediately kills active relay connections for that service. +func TestRemoveRoute_KillsActiveRelays(t *testing.T) { + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + // Backend echoes first message, then holds connection open. + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + // Hold the connection open. + for { + if _, err := c.Read(buf); err != nil { + return + } + } + }(conn) + } + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + ServiceID: "svc-1", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Establish a relay connection. + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer conn.Close() + _, err = conn.Write([]byte("hello")) + require.NoError(t, err) + + // Wait for echo to confirm relay is established. + buf := make([]byte, 16) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + _ = conn.SetReadDeadline(time.Time{}) + + // Remove the fallback: should kill the active relay. + router.RemoveFallback("svc-1") + + // The client connection should see an error (server closed). + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Read(buf) + assert.Error(t, err, "connection should be killed after service removal") +} + +// TestRemoveRoute_KillsSNIRelays verifies that removing an SNI route +// kills its active relays without affecting other services. +func TestRemoveRoute_KillsSNIRelays(t *testing.T) { + backend := startEchoTLS(t) + defer backend.Close() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tls.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-tls", + Target: backend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Establish a TLS relay. + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "tls.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + defer tlsConn.Close() + + _, err = tlsConn.Write([]byte("ping")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "ping", string(buf[:n])) + + // Remove the route: active relay should die. + router.RemoveRoute("tls.example.com", "svc-tls") + + _ = tlsConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = tlsConn.Read(buf) + assert.Error(t, err, "TLS relay should be killed after route removal") +} + +// TestPortRouter_SNIAndTCPFallbackCoexist verifies that a single port can +// serve both SNI-routed TLS passthrough and plain TCP fallback simultaneously. +func TestPortRouter_SNIAndTCPFallbackCoexist(t *testing.T) { + sniBackend := startEchoTLS(t) + fbBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + + // SNI route for a specific domain. + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-sni", + Target: sniBackend.Addr().String(), + }) + // TCP fallback for everything else. + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "acct-2", + ServiceID: "svc-fb", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // 1. TLS with matching SNI → goes to SNI backend. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + _, err = tlsConn.Write([]byte("sni-data")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "sni-data", string(buf[:n]), "SNI match → SNI backend") + tlsConn.Close() + + // 2. Plain TCP (no TLS) → goes to fallback. + tcpConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + + _, err = tcpConn.Write([]byte("plain")) + require.NoError(t, err) + n, err = tcpConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "plain", string(buf[:n]), "plain TCP → fallback backend") + tcpConn.Close() + + // 3. TLS with unknown SNI → also goes to fallback. + unknownBackend := startEchoTLS(t) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "acct-2", + ServiceID: "svc-fb", + Target: unknownBackend.Addr().String(), + }) + + unknownTLS, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "unknown.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + _, err = unknownTLS.Write([]byte("unknown-sni")) + require.NoError(t, err) + n, err = unknownTLS.Read(buf) + require.NoError(t, err) + assert.Equal(t, "unknown-sni", string(buf[:n]), "unknown SNI → fallback backend") + unknownTLS.Close() +} + +// TestPortRouter_UpdateRouteSwapsSNI verifies that updating a route +// (remove + add with different target) correctly routes to the new backend. +func TestPortRouter_UpdateRouteSwapsSNI(t *testing.T) { + backend1 := startEchoTLS(t) + backend2 := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Initial route → backend1. + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: backend1.Addr().String(), + }) + + conn1, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn1.Write([]byte("v1")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "v1", string(buf[:n])) + conn1.Close() + + // Update: remove old route, add new → backend2. + router.RemoveRoute("db.example.com", "svc-db") + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: backend2.Addr().String(), + }) + + conn2, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn2.Write([]byte("v2")) + require.NoError(t, err) + n, err = conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "v2", string(buf[:n])) + conn2.Close() +} + +// TestPortRouter_RemoveSNIFallsThrough verifies that after removing an +// SNI route, connections for that domain fall through to the fallback. +func TestPortRouter_RemoveSNIFallsThrough(t *testing.T) { + sniBackend := startEchoTLS(t) + fbBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: sniBackend.Addr().String(), + }) + router.SetFallback(Route{ + Type: RouteTCP, + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Before removal: SNI matches → sniBackend. + conn1, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn1.Write([]byte("before")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "before", string(buf[:n])) + conn1.Close() + + // Remove SNI route. Should fall through to fallback. + router.RemoveRoute("db.example.com", "svc-db") + + conn2, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn2.Write([]byte("after")) + require.NoError(t, err) + n, err = conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "after", string(buf[:n]), "after removal, should reach fallback") + conn2.Close() +} + +// TestPortRouter_RemoveFallbackCloses verifies that after removing the +// fallback, non-matching connections are closed. +func TestPortRouter_RemoveFallbackCloses(t *testing.T) { + fbBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + ServiceID: "svc-fb", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // With fallback: plain TCP works. + conn1, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + _, err = conn1.Write([]byte("hello")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + conn1.Close() + + // Remove fallback. + router.RemoveFallback("svc-fb") + + // Without fallback on a port router (no HTTP channel): connection should be closed. + conn2, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn2.Close() + _, _ = conn2.Write([]byte("bye")) + _ = conn2.SetReadDeadline(time.Now().Add(3 * time.Second)) + _, err = conn2.Read(buf) + assert.Error(t, err, "without fallback, connection should be closed") +} + +// TestPortRouter_HTTPToTLSTransition verifies that switching a service from +// HTTP-only to TLS-only via remove+add doesn't orphan the old HTTP route. +func TestPortRouter_HTTPToTLSTransition(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + tlsBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewRouter(logger, dialResolve, addr) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Phase 1: HTTP-only. SNI connections go to HTTP channel. + router.AddRoute("app.example.com", Route{Type: RouteHTTP, AccountID: "acct-1", ServiceID: "svc-1"}) + + httpConn := router.HTTPListener() + connDone := make(chan struct{}) + go func() { + defer close(connDone) + c, err := httpConn.Accept() + if err == nil { + c.Close() + } + }() + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + if err == nil { + tlsConn.Close() + } + select { + case <-connDone: + case <-time.After(2 * time.Second): + t.Fatal("HTTP listener did not receive connection for HTTP-only route") + } + + // Phase 2: Simulate update to TLS-only (removeMapping + addMapping). + router.RemoveRoute("app.example.com", "svc-1") + router.AddRoute("app.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-1", + Target: tlsBackend.Addr().String(), + }) + + tlsConn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err, "TLS connection should succeed after HTTP→TLS transition") + defer tlsConn2.Close() + + _, err = tlsConn2.Write([]byte("hello-tls")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello-tls", string(buf[:n]), "data should relay to TLS backend") +} + +// TestPortRouter_TLSToHTTPTransition verifies that switching a service from +// TLS-only to HTTP-only via remove+add doesn't orphan the old TLS route. +func TestPortRouter_TLSToHTTPTransition(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + tlsBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewRouter(logger, dialResolve, addr) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Phase 1: TLS-only. Route relays to backend. + router.AddRoute("app.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-1", + Target: tlsBackend.Addr().String(), + }) + + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err, "TLS relay should work before transition") + _, err = tlsConn.Write([]byte("tls-data")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "tls-data", string(buf[:n])) + tlsConn.Close() + + // Phase 2: Simulate update to HTTP-only (removeMapping + addMapping). + router.RemoveRoute("app.example.com", "svc-1") + router.AddRoute("app.example.com", Route{Type: RouteHTTP, AccountID: "acct-1", ServiceID: "svc-1"}) + + // TLS connection should now go to the HTTP listener, NOT to the old TLS backend. + httpConn := router.HTTPListener() + connDone := make(chan struct{}) + go func() { + defer close(connDone) + c, err := httpConn.Accept() + if err == nil { + c.Close() + } + }() + tlsConn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + if err == nil { + tlsConn2.Close() + } + select { + case <-connDone: + case <-time.After(2 * time.Second): + t.Fatal("HTTP listener should receive connection after TLS→HTTP transition") + } +} + +// TestPortRouter_MultiDomainSamePort verifies that two TLS services sharing +// the same port router are independently routable and removable. +func TestPortRouter_MultiDomainSamePort(t *testing.T) { + logger := log.StandardLogger() + backend1 := startEchoTLSMulti(t) + backend2 := startEchoTLSMulti(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + router.AddRoute("svc1.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-1", Target: backend1.Addr().String()}) + router.AddRoute("svc2.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-2", Target: backend2.Addr().String()}) + assert.False(t, router.IsEmpty()) + + // Both domains route independently. + for _, tc := range []struct { + sni string + data string + }{ + {"svc1.example.com", "hello-svc1"}, + {"svc2.example.com", "hello-svc2"}, + } { + conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: tc.sni, InsecureSkipVerify: true}, + ) + require.NoError(t, err, "dial %s", tc.sni) + _, err = conn.Write([]byte(tc.data)) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, tc.data, string(buf[:n])) + conn.Close() + } + + // Remove svc1. Router should NOT be empty (svc2 still present). + router.RemoveRoute("svc1.example.com", "svc-1") + assert.False(t, router.IsEmpty(), "router should not be empty with one route remaining") + + // svc2 still works. + conn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "svc2.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + _, err = conn2.Write([]byte("still-alive")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "still-alive", string(buf[:n])) + conn2.Close() + + // Remove svc2. Router is now empty. + router.RemoveRoute("svc2.example.com", "svc-2") + assert.True(t, router.IsEmpty(), "router should be empty after removing all routes") +} + +// TestPortRouter_SNIAndFallbackLifecycle verifies the full lifecycle of SNI +// routes and TCP fallback coexisting on the same port router, including the +// ordering of add/remove operations. +func TestPortRouter_SNIAndFallbackLifecycle(t *testing.T) { + logger := log.StandardLogger() + sniBackend := startEchoTLS(t) + fallbackBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Step 1: Add fallback first (port mapping), then SNI route (TLS service). + router.SetFallback(Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "pm-1", Target: fallbackBackend.Addr().String()}) + router.AddRoute("tls.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-1", Target: sniBackend.Addr().String()}) + assert.False(t, router.IsEmpty()) + + // SNI traffic goes to TLS backend. + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "tls.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + _, err = tlsConn.Write([]byte("sni-traffic")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "sni-traffic", string(buf[:n])) + tlsConn.Close() + + // Plain TCP goes to fallback. + plainConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + _, err = plainConn.Write([]byte("plain")) + require.NoError(t, err) + n, err = plainConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "plain", string(buf[:n])) + plainConn.Close() + + // Step 2: Remove SNI route. Fallback still works, router not empty. + router.RemoveRoute("tls.example.com", "svc-1") + assert.False(t, router.IsEmpty(), "fallback still present") + + plainConn2, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + // Must send >= 5 bytes so the SNI peek completes immediately + // without waiting for the 5-second peek timeout. + _, err = plainConn2.Write([]byte("after")) + require.NoError(t, err) + n, err = plainConn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "after", string(buf[:n])) + plainConn2.Close() + + // Step 3: Remove fallback. Router is now empty. + router.RemoveFallback("pm-1") + assert.True(t, router.IsEmpty()) +} + +// TestPortRouter_IsEmptyTransitions verifies IsEmpty reflects correct state +// through all add/remove operations. +func TestPortRouter_IsEmptyTransitions(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + assert.True(t, router.IsEmpty(), "new router") + + router.AddRoute("a.com", Route{Type: RouteTCP, ServiceID: "svc-a"}) + assert.False(t, router.IsEmpty(), "after adding route") + + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb1"}) + assert.False(t, router.IsEmpty(), "route + fallback") + + router.RemoveRoute("a.com", "svc-a") + assert.False(t, router.IsEmpty(), "fallback only") + + router.RemoveFallback("svc-fb1") + assert.True(t, router.IsEmpty(), "all removed") + + // Reverse order: fallback first, then route. + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb2"}) + assert.False(t, router.IsEmpty()) + + router.AddRoute("b.com", Route{Type: RouteTCP, ServiceID: "svc-b"}) + assert.False(t, router.IsEmpty()) + + router.RemoveFallback("svc-fb2") + assert.False(t, router.IsEmpty(), "route still present") + + router.RemoveRoute("b.com", "svc-b") + assert.True(t, router.IsEmpty(), "fully empty again") +} + +// startEchoTLSMulti starts a TLS echo server that accepts multiple connections. +func startEchoTLSMulti(t *testing.T) net.Listener { + t.Helper() + + cert := generateSelfSignedCert(t) + ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + }(conn) + } + }() + + return ln +} + +// startEchoPlain starts a plain TCP echo server that reads until newline +// or connection close, then echoes the received data. +func startEchoPlain(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + // Set a read deadline so we don't block forever waiting for more data. + _ = c.SetReadDeadline(time.Now().Add(2 * time.Second)) + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + }(conn) + } + }() + + return ln +} diff --git a/proxy/internal/tcp/snipeek.go b/proxy/internal/tcp/snipeek.go new file mode 100644 index 000000000..25ab8e5ef --- /dev/null +++ b/proxy/internal/tcp/snipeek.go @@ -0,0 +1,191 @@ +package tcp + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" +) + +const ( + // TLS record header is 5 bytes: ContentType(1) + Version(2) + Length(2). + tlsRecordHeaderLen = 5 + // TLS handshake type for ClientHello. + handshakeTypeClientHello = 1 + // TLS ContentType for handshake messages. + contentTypeHandshake = 22 + // SNI extension type (RFC 6066). + extensionServerName = 0 + // SNI host name type. + sniHostNameType = 0 + // maxClientHelloLen caps the ClientHello size we're willing to buffer. + maxClientHelloLen = 16384 + // maxSNILen is the maximum valid DNS hostname length per RFC 1035. + maxSNILen = 253 +) + +// PeekClientHello reads the TLS ClientHello from conn, extracts the SNI +// server name, and returns a wrapped connection that replays the peeked +// bytes transparently. If the data is not a valid TLS ClientHello or +// contains no SNI extension, sni is empty and err is nil. +// +// ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the +// real server name is encrypted inside the encrypted_client_hello +// extension. This parser only reads the cleartext server_name extension +// (type 0x0000), so ECH connections return sni="" and are routed through +// the fallback path (or HTTP channel), which is the correct behavior +// for a transparent proxy that does not terminate TLS. +func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) { + // Read the 5-byte TLS record header into a small stack-friendly buffer. + var header [tlsRecordHeaderLen]byte + if _, err := io.ReadFull(conn, header[:]); err != nil { + return "", nil, fmt.Errorf("read TLS record header: %w", err) + } + + if header[0] != contentTypeHandshake { + return "", newPeekedConn(conn, header[:]), nil + } + + recordLen := int(binary.BigEndian.Uint16(header[3:5])) + if recordLen == 0 || recordLen > maxClientHelloLen { + return "", newPeekedConn(conn, header[:]), nil + } + + // Single allocation for header + payload. The peekedConn takes + // ownership of this buffer, so no further copies are needed. + buf := make([]byte, tlsRecordHeaderLen+recordLen) + copy(buf, header[:]) + + n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:]) + if err != nil { + return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err) + } + + sni = extractSNI(buf[tlsRecordHeaderLen:]) + return sni, newPeekedConn(conn, buf), nil +} + +// extractSNI parses a TLS handshake payload to find the SNI extension. +// Returns empty string if the payload is not a ClientHello or has no SNI. +func extractSNI(payload []byte) string { + if len(payload) < 4 { + return "" + } + + if payload[0] != handshakeTypeClientHello { + return "" + } + + // Handshake length (3 bytes, big-endian). + handshakeLen := int(payload[1])<<16 | int(payload[2])<<8 | int(payload[3]) + if handshakeLen > len(payload)-4 { + return "" + } + + return parseSNIFromClientHello(payload[4 : 4+handshakeLen]) +} + +// parseSNIFromClientHello walks the ClientHello message fields to reach +// the extensions block and extract the server_name extension value. +func parseSNIFromClientHello(msg []byte) string { + // ClientHello layout: + // ProtocolVersion(2) + Random(32) = 34 bytes minimum before session_id + if len(msg) < 34 { + return "" + } + + pos := 34 + + // Session ID (variable, 1 byte length prefix). + if pos >= len(msg) { + return "" + } + sessionIDLen := int(msg[pos]) + pos++ + pos += sessionIDLen + + // Cipher suites (variable, 2 byte length prefix). + if pos+2 > len(msg) { + return "" + } + cipherSuitesLen := int(binary.BigEndian.Uint16(msg[pos : pos+2])) + pos += 2 + cipherSuitesLen + + // Compression methods (variable, 1 byte length prefix). + if pos >= len(msg) { + return "" + } + compMethodsLen := int(msg[pos]) + pos++ + pos += compMethodsLen + + // Extensions (variable, 2 byte length prefix). + if pos+2 > len(msg) { + return "" + } + extensionsLen := int(binary.BigEndian.Uint16(msg[pos : pos+2])) + pos += 2 + + extensionsEnd := pos + extensionsLen + if extensionsEnd > len(msg) { + return "" + } + + return findSNIExtension(msg[pos:extensionsEnd]) +} + +// findSNIExtension iterates over TLS extensions and returns the host +// name from the server_name extension, if present. +func findSNIExtension(extensions []byte) string { + pos := 0 + for pos+4 <= len(extensions) { + extType := binary.BigEndian.Uint16(extensions[pos : pos+2]) + extLen := int(binary.BigEndian.Uint16(extensions[pos+2 : pos+4])) + pos += 4 + + if pos+extLen > len(extensions) { + return "" + } + + if extType == extensionServerName { + return parseSNIExtensionData(extensions[pos : pos+extLen]) + } + pos += extLen + } + return "" +} + +// parseSNIExtensionData parses the ServerNameList structure inside an +// SNI extension to extract the host name. +func parseSNIExtensionData(data []byte) string { + if len(data) < 2 { + return "" + } + listLen := int(binary.BigEndian.Uint16(data[0:2])) + if listLen > len(data)-2 { + return "" + } + + list := data[2 : 2+listLen] + pos := 0 + for pos+3 <= len(list) { + nameType := list[pos] + nameLen := int(binary.BigEndian.Uint16(list[pos+1 : pos+3])) + pos += 3 + + if pos+nameLen > len(list) { + return "" + } + + if nameType == sniHostNameType { + name := list[pos : pos+nameLen] + if nameLen > maxSNILen || bytes.ContainsRune(name, 0) { + return "" + } + return string(name) + } + pos += nameLen + } + return "" +} diff --git a/proxy/internal/tcp/snipeek_test.go b/proxy/internal/tcp/snipeek_test.go new file mode 100644 index 000000000..9afe6261d --- /dev/null +++ b/proxy/internal/tcp/snipeek_test.go @@ -0,0 +1,251 @@ +package tcp + +import ( + "crypto/tls" + "io" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPeekClientHello_ValidSNI(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + const expectedSNI = "example.com" + trailingData := []byte("trailing data after handshake") + + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: expectedSNI, + InsecureSkipVerify: true, //nolint:gosec + }) + // The Handshake will send the ClientHello. It will fail because + // our server side isn't doing a real TLS handshake, but that's + // fine: we only need the ClientHello to be sent. + _ = tlsConn.Handshake() + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello") + assert.NotNil(t, wrapped, "wrapped connection should not be nil") + + // Verify the wrapped connection replays the peeked bytes. + // Read the first 5 bytes (TLS record header) to confirm replay. + buf := make([]byte, 5) + n, err := wrapped.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, byte(contentTypeHandshake), buf[0], "first byte should be TLS handshake content type") + + // Write trailing data from the client side and verify it arrives + // through the wrapped connection after the peeked bytes. + go func() { + _, _ = clientConn.Write(trailingData) + }() + + // Drain the rest of the peeked ClientHello first. + peekedRest := make([]byte, 16384) + _, _ = wrapped.Read(peekedRest) + + got := make([]byte, len(trailingData)) + n, err = io.ReadFull(wrapped, got) + require.NoError(t, err) + assert.Equal(t, trailingData, got[:n]) +} + +func TestPeekClientHello_MultipleSNIs(t *testing.T) { + tests := []struct { + name string + serverName string + expectedSNI string + }{ + {"simple domain", "example.com", "example.com"}, + {"subdomain", "sub.example.com", "sub.example.com"}, + {"deep subdomain", "a.b.c.example.com", "a.b.c.example.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: tt.serverName, + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Equal(t, tt.expectedSNI, sni) + assert.NotNil(t, wrapped) + }) + } +} + +func TestPeekClientHello_NonTLSData(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + // Send plain HTTP data (not TLS). + httpData := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + go func() { + _, _ = clientConn.Write(httpData) + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Empty(t, sni, "should return empty SNI for non-TLS data") + assert.NotNil(t, wrapped) + + // Verify the wrapped connection still provides the original data. + buf := make([]byte, len(httpData)) + n, err := io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, httpData, buf[:n], "wrapped connection should replay original data") +} + +func TestPeekClientHello_TruncatedHeader(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + // Write only 3 bytes then close, fewer than the 5-byte TLS header. + go func() { + _, _ = clientConn.Write([]byte{0x16, 0x03, 0x01}) + clientConn.Close() + }() + + _, _, err := PeekClientHello(serverConn) + assert.Error(t, err, "should error on truncated header") +} + +func TestPeekClientHello_TruncatedPayload(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + // Write a valid TLS header claiming 100 bytes, but only send 10. + go func() { + header := []byte{0x16, 0x03, 0x01, 0x00, 0x64} // 100 bytes claimed + _, _ = clientConn.Write(header) + _, _ = clientConn.Write(make([]byte, 10)) + clientConn.Close() + }() + + _, _, err := PeekClientHello(serverConn) + assert.Error(t, err, "should error on truncated payload") +} + +func TestPeekClientHello_ZeroLengthRecord(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + // TLS handshake header with zero-length payload. + go func() { + _, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00}) + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Empty(t, sni) + assert.NotNil(t, wrapped) +} + +func TestExtractSNI_InvalidPayload(t *testing.T) { + tests := []struct { + name string + payload []byte + }{ + {"nil", nil}, + {"empty", []byte{}}, + {"too short", []byte{0x01, 0x00}}, + {"wrong handshake type", []byte{0x02, 0x00, 0x00, 0x05, 0x03, 0x03, 0x00, 0x00, 0x00}}, + {"truncated client hello", []byte{0x01, 0x00, 0x00, 0x20}}, // claims 32 bytes but has none + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Empty(t, extractSNI(tt.payload)) + }) + } +} + +func TestPeekedConn_CloseWrite(t *testing.T) { + t.Run("delegates to underlying TCPConn", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + accepted := make(chan net.Conn, 1) + go func() { + c, err := ln.Accept() + if err == nil { + accepted <- c + } + }() + + client, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer client.Close() + + server := <-accepted + defer server.Close() + + wrapped := newPeekedConn(server, []byte("peeked")) + + // CloseWrite should succeed on a real TCP connection. + err = wrapped.CloseWrite() + assert.NoError(t, err) + + // The client should see EOF on reads after CloseWrite. + buf := make([]byte, 1) + _, err = client.Read(buf) + assert.Equal(t, io.EOF, err, "client should see EOF after half-close") + }) + + t.Run("no-op on non-halfcloser", func(t *testing.T) { + // net.Pipe does not implement CloseWrite. + _, server := net.Pipe() + defer server.Close() + + wrapped := newPeekedConn(server, []byte("peeked")) + err := wrapped.CloseWrite() + assert.NoError(t, err, "should be no-op on non-halfcloser") + }) +} + +func TestPeekedConn_ReplayAndPassthrough(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + peeked := []byte("peeked-data") + subsequent := []byte("subsequent-data") + + wrapped := newPeekedConn(serverConn, peeked) + + go func() { + _, _ = clientConn.Write(subsequent) + }() + + // Read should return peeked data first. + buf := make([]byte, len(peeked)) + n, err := io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, peeked, buf[:n]) + + // Then subsequent data from the real connection. + buf = make([]byte, len(subsequent)) + n, err = io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, subsequent, buf[:n]) +} diff --git a/proxy/internal/types/types.go b/proxy/internal/types/types.go index 41acfef40..bf3731803 100644 --- a/proxy/internal/types/types.go +++ b/proxy/internal/types/types.go @@ -1,5 +1,56 @@ // Package types defines common types used across the proxy package. package types +import ( + "context" + "net" + "time" +) + // AccountID represents a unique identifier for a NetBird account. type AccountID string + +// ServiceID represents a unique identifier for a proxy service. +type ServiceID string + +// ServiceMode describes how a reverse proxy service is exposed. +type ServiceMode string + +const ( + ServiceModeHTTP ServiceMode = "http" + ServiceModeTCP ServiceMode = "tcp" + ServiceModeUDP ServiceMode = "udp" + ServiceModeTLS ServiceMode = "tls" +) + +// IsL4 returns true for TCP, UDP, and TLS modes. +func (m ServiceMode) IsL4() bool { + return m == ServiceModeTCP || m == ServiceModeUDP || m == ServiceModeTLS +} + +// RelayDirection indicates the direction of a relayed packet. +type RelayDirection string + +const ( + RelayDirectionClientToBackend RelayDirection = "client_to_backend" + RelayDirectionBackendToClient RelayDirection = "backend_to_client" +) + +// DialContextFunc dials a backend through the WireGuard tunnel. +type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// dialTimeoutKey is the context key for a per-request dial timeout. +type dialTimeoutKey struct{} + +// WithDialTimeout returns a context carrying a dial timeout that +// DialContext wrappers can use to scope the timeout to just the +// connection establishment phase. +func WithDialTimeout(ctx context.Context, d time.Duration) context.Context { + return context.WithValue(ctx, dialTimeoutKey{}, d) +} + +// DialTimeoutFromContext returns the dial timeout from the context, if set. +func DialTimeoutFromContext(ctx context.Context) (time.Duration, bool) { + d, ok := ctx.Value(dialTimeoutKey{}).(time.Duration) + return d, ok && d > 0 +} diff --git a/proxy/internal/types/types_test.go b/proxy/internal/types/types_test.go new file mode 100644 index 000000000..dd9738442 --- /dev/null +++ b/proxy/internal/types/types_test.go @@ -0,0 +1,54 @@ +package types + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestServiceMode_IsL4(t *testing.T) { + tests := []struct { + mode ServiceMode + want bool + }{ + {ServiceModeHTTP, false}, + {ServiceModeTCP, true}, + {ServiceModeUDP, true}, + {ServiceModeTLS, true}, + {ServiceMode("unknown"), false}, + } + + for _, tt := range tests { + t.Run(string(tt.mode), func(t *testing.T) { + assert.Equal(t, tt.want, tt.mode.IsL4()) + }) + } +} + +func TestDialTimeoutContext(t *testing.T) { + t.Run("round trip", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), 5*time.Second) + d, ok := DialTimeoutFromContext(ctx) + assert.True(t, ok) + assert.Equal(t, 5*time.Second, d) + }) + + t.Run("missing", func(t *testing.T) { + _, ok := DialTimeoutFromContext(context.Background()) + assert.False(t, ok) + }) + + t.Run("zero returns false", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), 0) + _, ok := DialTimeoutFromContext(ctx) + assert.False(t, ok, "zero duration should return ok=false") + }) + + t.Run("negative returns false", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), -1*time.Second) + _, ok := DialTimeoutFromContext(ctx) + assert.False(t, ok, "negative duration should return ok=false") + }) +} diff --git a/proxy/internal/udp/relay.go b/proxy/internal/udp/relay.go new file mode 100644 index 000000000..f2f58e858 --- /dev/null +++ b/proxy/internal/udp/relay.go @@ -0,0 +1,496 @@ +package udp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/time/rate" + + "github.com/netbirdio/netbird/proxy/internal/accesslog" + "github.com/netbirdio/netbird/proxy/internal/netutil" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +const ( + // DefaultSessionTTL is the default idle timeout for UDP sessions before cleanup. + DefaultSessionTTL = 30 * time.Second + // cleanupInterval is how often the cleaner goroutine runs. + cleanupInterval = time.Minute + // maxPacketSize is the maximum UDP packet size we'll handle. + maxPacketSize = 65535 + // DefaultMaxSessions is the default cap on concurrent UDP sessions per relay. + DefaultMaxSessions = 1024 + // sessionCreateRate limits new session creation per second. + sessionCreateRate = 50 + // sessionCreateBurst is the burst allowance for session creation. + sessionCreateBurst = 100 + // defaultDialTimeout is the fallback dial timeout for backend connections. + defaultDialTimeout = 30 * time.Second +) + +// l4Logger sends layer-4 access log entries to the management server. +type l4Logger interface { + LogL4(entry accesslog.L4Entry) +} + +// SessionObserver receives callbacks for UDP session lifecycle events. +// All methods must be safe for concurrent use. +type SessionObserver interface { + UDPSessionStarted(accountID types.AccountID) + UDPSessionEnded(accountID types.AccountID) + UDPSessionDialError(accountID types.AccountID) + UDPSessionRejected(accountID types.AccountID) + UDPPacketRelayed(direction types.RelayDirection, bytes int) +} + +// clientAddr is a typed key for UDP session lookups. +type clientAddr string + +// Relay listens for incoming UDP packets on a dedicated port and +// maintains per-client sessions that relay packets to a backend +// through the WireGuard tunnel. +type Relay struct { + logger *log.Entry + listener net.PacketConn + target string + domain string + accountID types.AccountID + serviceID types.ServiceID + dialFunc types.DialContextFunc + dialTimeout time.Duration + sessionTTL time.Duration + maxSessions int + + mu sync.RWMutex + sessions map[clientAddr]*session + + bufPool sync.Pool + sessLimiter *rate.Limiter + sessWg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + observer SessionObserver + accessLog l4Logger +} + +type session struct { + backend net.Conn + addr net.Addr + createdAt time.Time + // lastSeen stores the last activity timestamp as unix nanoseconds. + lastSeen atomic.Int64 + cancel context.CancelFunc + // bytesIn tracks total bytes received from the client. + bytesIn atomic.Int64 + // bytesOut tracks total bytes sent back to the client. + bytesOut atomic.Int64 +} + +func (s *session) updateLastSeen() { + s.lastSeen.Store(time.Now().UnixNano()) +} + +func (s *session) idleDuration() time.Duration { + return time.Since(time.Unix(0, s.lastSeen.Load())) +} + +// RelayConfig holds the configuration for a UDP relay. +type RelayConfig struct { + Logger *log.Entry + Listener net.PacketConn + Target string + Domain string + AccountID types.AccountID + ServiceID types.ServiceID + DialFunc types.DialContextFunc + DialTimeout time.Duration + SessionTTL time.Duration + MaxSessions int + AccessLog l4Logger +} + +// New creates a UDP relay for the given listener and backend target. +// MaxSessions caps the number of concurrent sessions; use 0 for DefaultMaxSessions. +// DialTimeout controls how long to wait for backend connections; use 0 for default. +// SessionTTL is the idle timeout before a session is reaped; use 0 for DefaultSessionTTL. +func New(parentCtx context.Context, cfg RelayConfig) *Relay { + maxSessions := cfg.MaxSessions + dialTimeout := cfg.DialTimeout + sessionTTL := cfg.SessionTTL + if maxSessions <= 0 { + maxSessions = DefaultMaxSessions + } + if dialTimeout <= 0 { + dialTimeout = defaultDialTimeout + } + if sessionTTL <= 0 { + sessionTTL = DefaultSessionTTL + } + ctx, cancel := context.WithCancel(parentCtx) + return &Relay{ + logger: cfg.Logger, + listener: cfg.Listener, + target: cfg.Target, + domain: cfg.Domain, + accountID: cfg.AccountID, + serviceID: cfg.ServiceID, + accessLog: cfg.AccessLog, + dialFunc: cfg.DialFunc, + dialTimeout: dialTimeout, + sessionTTL: sessionTTL, + maxSessions: maxSessions, + sessions: make(map[clientAddr]*session), + bufPool: sync.Pool{ + New: func() any { + buf := make([]byte, maxPacketSize) + return &buf + }, + }, + sessLimiter: rate.NewLimiter(sessionCreateRate, sessionCreateBurst), + ctx: ctx, + cancel: cancel, + } +} + +// ServiceID returns the service ID associated with this relay. +func (r *Relay) ServiceID() types.ServiceID { + return r.serviceID +} + +// SetObserver sets the session lifecycle observer. Must be called before Serve. +func (r *Relay) SetObserver(obs SessionObserver) { + r.observer = obs +} + +// Serve starts the relay loop. It blocks until the context is canceled +// or the listener is closed. +func (r *Relay) Serve() { + go r.cleanupLoop() + + for { + bufp := r.bufPool.Get().(*[]byte) + buf := *bufp + + n, addr, err := r.listener.ReadFrom(buf) + if err != nil { + r.bufPool.Put(bufp) + if r.ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + return + } + r.logger.Debugf("UDP read: %v", err) + continue + } + + data := buf[:n] + sess, err := r.getOrCreateSession(addr) + if err != nil { + r.bufPool.Put(bufp) + r.logger.Debugf("create UDP session for %s: %v", addr, err) + continue + } + + sess.updateLastSeen() + + nw, err := sess.backend.Write(data) + if err != nil { + r.bufPool.Put(bufp) + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP write to backend for %s: %v", addr, err) + } + r.removeSession(sess) + continue + } + sess.bytesIn.Add(int64(nw)) + + if r.observer != nil { + r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw) + } + r.bufPool.Put(bufp) + } +} + +// getOrCreateSession returns an existing session or creates a new one. +func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) { + key := clientAddr(addr.String()) + + r.mu.RLock() + sess, ok := r.sessions[key] + r.mu.RUnlock() + if ok && sess != nil { + return sess, nil + } + + // Check before taking the write lock: if the relay is shutting down, + // don't create new sessions. This prevents orphaned goroutines when + // Serve() processes a packet that was already read before Close(). + if r.ctx.Err() != nil { + return nil, r.ctx.Err() + } + + r.mu.Lock() + + if sess, ok = r.sessions[key]; ok && sess != nil { + r.mu.Unlock() + return sess, nil + } + if ok { + // Another goroutine is dialing for this key, skip. + r.mu.Unlock() + return nil, fmt.Errorf("session dial in progress for %s", key) + } + + if len(r.sessions) >= r.maxSessions { + r.mu.Unlock() + if r.observer != nil { + r.observer.UDPSessionRejected(r.accountID) + } + return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions) + } + + if !r.sessLimiter.Allow() { + r.mu.Unlock() + if r.observer != nil { + r.observer.UDPSessionRejected(r.accountID) + } + return nil, fmt.Errorf("session creation rate limited") + } + + // Reserve the slot with a nil session so concurrent callers for the same + // key see it exists and wait. Release the lock before dialing. + r.sessions[key] = nil + r.mu.Unlock() + + dialCtx, dialCancel := context.WithTimeout(r.ctx, r.dialTimeout) + backend, err := r.dialFunc(dialCtx, "udp", r.target) + dialCancel() + if err != nil { + r.mu.Lock() + delete(r.sessions, key) + r.mu.Unlock() + if r.observer != nil { + r.observer.UDPSessionDialError(r.accountID) + } + return nil, fmt.Errorf("dial backend %s: %w", r.target, err) + } + + sessCtx, sessCancel := context.WithCancel(r.ctx) + sess = &session{ + backend: backend, + addr: addr, + createdAt: time.Now(), + cancel: sessCancel, + } + sess.updateLastSeen() + + r.mu.Lock() + r.sessions[key] = sess + r.mu.Unlock() + + if r.observer != nil { + r.observer.UDPSessionStarted(r.accountID) + } + + r.sessWg.Go(func() { + r.relayBackendToClient(sessCtx, sess) + }) + + r.logger.Debugf("UDP session created for %s", addr) + return sess, nil +} + +// relayBackendToClient reads packets from the backend and writes them +// back to the client through the public-facing listener. +func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) { + bufp := r.bufPool.Get().(*[]byte) + defer r.bufPool.Put(bufp) + defer r.removeSession(sess) + + for ctx.Err() == nil { + data, ok := r.readBackendPacket(sess, *bufp) + if !ok { + return + } + if data == nil { + continue + } + + sess.updateLastSeen() + + nw, err := r.listener.WriteTo(data, sess.addr) + if err != nil { + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP write to client %s: %v", sess.addr, err) + } + return + } + sess.bytesOut.Add(int64(nw)) + + if r.observer != nil { + r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw) + } + } +} + +// readBackendPacket reads one packet from the backend with an idle deadline. +// Returns (data, true) on success, (nil, true) on idle timeout that should +// retry, or (nil, false) when the session should be torn down. +func (r *Relay) readBackendPacket(sess *session, buf []byte) ([]byte, bool) { + if err := sess.backend.SetReadDeadline(time.Now().Add(r.sessionTTL)); err != nil { + r.logger.Debugf("set backend read deadline for %s: %v", sess.addr, err) + return nil, false + } + + n, err := sess.backend.Read(buf) + if err != nil { + if netutil.IsTimeout(err) { + if sess.idleDuration() > r.sessionTTL { + return nil, false + } + return nil, true + } + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP read from backend for %s: %v", sess.addr, err) + } + return nil, false + } + + return buf[:n], true +} + +// cleanupLoop periodically removes idle sessions. +func (r *Relay) cleanupLoop() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-r.ctx.Done(): + return + case <-ticker.C: + r.cleanupIdleSessions() + } + } +} + +// cleanupIdleSessions closes sessions that have been idle for too long. +func (r *Relay) cleanupIdleSessions() { + var expired []*session + + r.mu.Lock() + for key, sess := range r.sessions { + if sess == nil { + continue + } + idle := sess.idleDuration() + if idle > r.sessionTTL { + r.logger.Debugf("UDP session %s idle for %s, closing (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, idle, sess.bytesIn.Load(), sess.bytesOut.Load()) + delete(r.sessions, key) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close idle session %s backend: %v", sess.addr, err) + } + expired = append(expired, sess) + } + } + r.mu.Unlock() + + for _, sess := range expired { + if r.observer != nil { + r.observer.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } +} + +// removeSession removes a session from the map if it still matches the +// given pointer. This is safe to call concurrently with cleanupIdleSessions +// because the identity check prevents double-close when both paths race. +func (r *Relay) removeSession(sess *session) { + r.mu.Lock() + key := clientAddr(sess.addr.String()) + removed := r.sessions[key] == sess + if removed { + delete(r.sessions, key) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close session %s backend: %v", sess.addr, err) + } + } + r.mu.Unlock() + + if removed { + r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load()) + if r.observer != nil { + r.observer.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } +} + +// logSessionEnd sends an access log entry for a completed UDP session. +func (r *Relay) logSessionEnd(sess *session) { + if r.accessLog == nil { + return + } + + var sourceIP netip.Addr + if ap, err := netip.ParseAddrPort(sess.addr.String()); err == nil { + sourceIP = ap.Addr().Unmap() + } + + r.accessLog.LogL4(accesslog.L4Entry{ + AccountID: r.accountID, + ServiceID: r.serviceID, + Protocol: accesslog.ProtocolUDP, + Host: r.domain, + SourceIP: sourceIP, + DurationMs: time.Unix(0, sess.lastSeen.Load()).Sub(sess.createdAt).Milliseconds(), + BytesUpload: sess.bytesIn.Load(), + BytesDownload: sess.bytesOut.Load(), + }) +} + +// Close stops the relay, waits for all session goroutines to exit, +// and cleans up remaining sessions. +func (r *Relay) Close() { + r.cancel() + if err := r.listener.Close(); err != nil { + r.logger.Debugf("close UDP listener: %v", err) + } + + var closedSessions []*session + r.mu.Lock() + for key, sess := range r.sessions { + if sess == nil { + delete(r.sessions, key) + continue + } + r.logger.Debugf("UDP session %s closed (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load()) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close session %s backend: %v", sess.addr, err) + } + delete(r.sessions, key) + closedSessions = append(closedSessions, sess) + } + r.mu.Unlock() + + for _, sess := range closedSessions { + if r.observer != nil { + r.observer.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } + + r.sessWg.Wait() +} diff --git a/proxy/internal/udp/relay_test.go b/proxy/internal/udp/relay_test.go new file mode 100644 index 000000000..a1e91b290 --- /dev/null +++ b/proxy/internal/udp/relay_test.go @@ -0,0 +1,493 @@ +package udp + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestRelay_BasicPacketExchange(t *testing.T) { + // Set up a UDP backend that echoes packets. + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + // Set up the relay's public-facing listener. + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + backendAddr := backend.LocalAddr().String() + + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backendAddr, DialFunc: dialFunc}) + go relay.Serve() + defer relay.Close() + + // Create a client and send a packet to the relay. + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err) + defer client.Close() + + testData := []byte("hello UDP relay") + _, err = client.Write(testData) + require.NoError(t, err) + + // Read the echoed response. + if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := client.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed packet") +} + +func TestRelay_MultipleClients(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay.Serve() + defer relay.Close() + + // Two clients, each should get their own session. + for i, msg := range []string{"client-1", "client-2"} { + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err, "client %d", i) + defer client.Close() + + _, err = client.Write([]byte(msg)) + require.NoError(t, err) + + if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := client.Read(buf) + require.NoError(t, err, "client %d read", i) + assert.Equal(t, msg, string(buf[:n]), "client %d should get own echo", i) + } + + // Verify two sessions were created. + relay.mu.RLock() + sessionCount := len(relay.sessions) + relay.mu.RUnlock() + assert.Equal(t, 2, sessionCount, "should have two sessions") +} + +func TestRelay_Close(t *testing.T) { + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: "127.0.0.1:9999", DialFunc: dialFunc}) + + done := make(chan struct{}) + go func() { + relay.Serve() + close(done) + }() + + relay.Close() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Serve did not return after Close") + } +} + +func TestRelay_SessionCleanup(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay.Serve() + defer relay.Close() + + // Create a session. + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err) + _, err = client.Write([]byte("hello")) + require.NoError(t, err) + + if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + _, err = client.Read(buf) + require.NoError(t, err) + client.Close() + + // Verify session exists. + relay.mu.RLock() + assert.Equal(t, 1, len(relay.sessions)) + relay.mu.RUnlock() + + // Make session appear idle by setting lastSeen to the past. + relay.mu.Lock() + for _, sess := range relay.sessions { + sess.lastSeen.Store(time.Now().Add(-2 * DefaultSessionTTL).UnixNano()) + } + relay.mu.Unlock() + + // Trigger cleanup manually. + relay.cleanupIdleSessions() + + relay.mu.RLock() + assert.Equal(t, 0, len(relay.sessions), "idle sessions should be cleaned up") + relay.mu.RUnlock() +} + +// TestRelay_CloseAndRecreate verifies that closing a relay and creating a new +// one on the same port works cleanly (simulates port mapping modify cycle). +func TestRelay_CloseAndRecreate(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + // First relay. + ln1, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + relay1 := New(ctx, RelayConfig{Logger: logger, Listener: ln1, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay1.Serve() + + client1, err := net.Dial("udp", ln1.LocalAddr().String()) + require.NoError(t, err) + _, err = client1.Write([]byte("relay1")) + require.NoError(t, err) + require.NoError(t, client1.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1024) + n, err := client1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "relay1", string(buf[:n])) + client1.Close() + + // Close first relay. + relay1.Close() + + // Second relay on same port. + port := ln1.LocalAddr().(*net.UDPAddr).Port + ln2, err := net.ListenPacket("udp", fmt.Sprintf("127.0.0.1:%d", port)) + require.NoError(t, err) + + relay2 := New(ctx, RelayConfig{Logger: logger, Listener: ln2, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay2.Serve() + defer relay2.Close() + + client2, err := net.Dial("udp", ln2.LocalAddr().String()) + require.NoError(t, err) + defer client2.Close() + _, err = client2.Write([]byte("relay2")) + require.NoError(t, err) + require.NoError(t, client2.SetReadDeadline(time.Now().Add(2*time.Second))) + n, err = client2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "relay2", string(buf[:n]), "second relay should work on same port") +} + +func TestRelay_SessionLimit(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + // Create a relay with a max of 2 sessions. + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc, MaxSessions: 2}) + go relay.Serve() + defer relay.Close() + + // Create 2 clients to fill up the session limit. + for i := range 2 { + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err, "client %d", i) + defer client.Close() + + _, err = client.Write([]byte("hello")) + require.NoError(t, err) + + require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1024) + _, err = client.Read(buf) + require.NoError(t, err, "client %d should get response", i) + } + + relay.mu.RLock() + assert.Equal(t, 2, len(relay.sessions), "should have exactly 2 sessions") + relay.mu.RUnlock() + + // Third client should get its packet dropped (session creation fails). + client3, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err) + defer client3.Close() + + _, err = client3.Write([]byte("should be dropped")) + require.NoError(t, err) + + require.NoError(t, client3.SetReadDeadline(time.Now().Add(500*time.Millisecond))) + buf := make([]byte, 1024) + _, err = client3.Read(buf) + assert.Error(t, err, "third client should time out because session was rejected") + + relay.mu.RLock() + assert.Equal(t, 2, len(relay.sessions), "session count should not exceed limit") + relay.mu.RUnlock() +} + +// testObserver records UDP session lifecycle events for test assertions. +type testObserver struct { + mu sync.Mutex + started int + ended int + rejected int + dialErr int + packets int + bytes int +} + +func (o *testObserver) UDPSessionStarted(types.AccountID) { o.mu.Lock(); o.started++; o.mu.Unlock() } +func (o *testObserver) UDPSessionEnded(types.AccountID) { o.mu.Lock(); o.ended++; o.mu.Unlock() } +func (o *testObserver) UDPSessionDialError(types.AccountID) { o.mu.Lock(); o.dialErr++; o.mu.Unlock() } +func (o *testObserver) UDPSessionRejected(types.AccountID) { o.mu.Lock(); o.rejected++; o.mu.Unlock() } +func (o *testObserver) UDPPacketRelayed(_ types.RelayDirection, b int) { + o.mu.Lock() + o.packets++ + o.bytes += b + o.mu.Unlock() +} + +func TestRelay_CloseFiresObserverEnded(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + obs := &testObserver{} + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc}) + relay.SetObserver(obs) + go relay.Serve() + + // Create two sessions. + for i := range 2 { + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err, "client %d", i) + + _, err = client.Write([]byte("hello")) + require.NoError(t, err) + + require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1024) + _, err = client.Read(buf) + require.NoError(t, err) + client.Close() + } + + obs.mu.Lock() + assert.Equal(t, 2, obs.started, "should have 2 started events") + obs.mu.Unlock() + + // Close should fire UDPSessionEnded for all remaining sessions. + relay.Close() + + obs.mu.Lock() + assert.Equal(t, 2, obs.ended, "Close should fire UDPSessionEnded for each session") + obs.mu.Unlock() +} + +func TestRelay_SessionRateLimit(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + obs := &testObserver{} + // High max sessions (1000) but the relay uses a rate limiter internally + // (default: 50/s burst 100). We exhaust the burst by creating sessions + // rapidly, then verify that subsequent creates are rejected. + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc, MaxSessions: 1000}) + relay.SetObserver(obs) + go relay.Serve() + defer relay.Close() + + // Exhaust the burst by calling getOrCreateSession directly with + // synthetic addresses. This is faster than real UDP round-trips. + for i := range sessionCreateBurst + 20 { + addr := &net.UDPAddr{IP: net.IPv4(10, 0, byte(i/256), byte(i%256)), Port: 10000 + i} + _, _ = relay.getOrCreateSession(addr) + } + + obs.mu.Lock() + rejected := obs.rejected + obs.mu.Unlock() + + assert.Greater(t, rejected, 0, "some sessions should be rate-limited") +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 6a0ecce30..ebecfc6f6 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -243,6 +243,10 @@ func (c *testProxyController) GetProxiesForCluster(_ string) []string { return nil } +func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { + return nil +} + // storeBackedServiceManager reads directly from the real store. type storeBackedServiceManager struct { store store.Store @@ -505,15 +509,15 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T nil, "", 0, - mapping.GetAccountId(), - mapping.GetId(), + proxytypes.AccountID(mapping.GetAccountId()), + proxytypes.ServiceID(mapping.GetId()), ) require.NoError(t, err) // Apply to real proxy (idempotent) proxyHandler.AddMapping(proxy.Mapping{ Host: mapping.GetDomain(), - ID: mapping.GetId(), + ID: proxytypes.ServiceID(mapping.GetId()), AccountID: proxytypes.AccountID(mapping.GetAccountId()), }) } diff --git a/proxy/server.go b/proxy/server.go index 123b14648..649d49c9a 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -30,6 +30,7 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/exporters/prometheus" "go.opentelemetry.io/otel/sdk/metric" + "golang.org/x/exp/maps" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -46,15 +47,26 @@ import ( "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/k8s" proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics" + "github.com/netbirdio/netbird/proxy/internal/netutil" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/roundtrip" + nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp" "github.com/netbirdio/netbird/proxy/internal/types" + udprelay "github.com/netbirdio/netbird/proxy/internal/udp" "github.com/netbirdio/netbird/proxy/web" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util/embeddedroots" ) + +// portRouter bundles a per-port Router with its listener and cancel func. +type portRouter struct { + router *nbtcp.Router + listener net.Listener + cancel context.CancelFunc +} + type Server struct { mgmtClient proto.ProxyServiceClient proxy *proxy.ReverseProxy @@ -67,12 +79,27 @@ type Server struct { healthServer *health.Server healthChecker *health.Checker meter *proxymetrics.Metrics + accessLog *accesslog.Logger + mainRouter *nbtcp.Router + mainPort uint16 + udpMu sync.Mutex + udpRelays map[types.ServiceID]*udprelay.Relay + udpRelayWg sync.WaitGroup + portMu sync.RWMutex + portRouters map[uint16]*portRouter + svcPorts map[types.ServiceID][]uint16 + lastMappings map[types.ServiceID]*proto.ProxyMapping + portRouterWg sync.WaitGroup // hijackTracker tracks hijacked connections (e.g. WebSocket upgrades) // so they can be closed during graceful shutdown, since http.Server.Shutdown // does not handle them. hijackTracker conntrack.HijackTracker + // routerReady is closed once mainRouter is fully initialized. + // The mapping worker waits on this before processing updates. + routerReady chan struct{} + // Mostly used for debugging on management. startTime time.Time @@ -97,6 +124,11 @@ type Server struct { // CertLockMethod controls how ACME certificate locks are coordinated // across replicas. Default: CertLockAuto (detect environment). CertLockMethod acme.CertLockMethod + // WildcardCertDir is an optional directory containing wildcard certificate + // pairs (.crt / .key). Wildcard patterns are extracted from + // the certificates' SAN lists. Matching domains use these static certs + // instead of ACME. + WildcardCertDir string // DebugEndpointEnabled enables the debug HTTP endpoint. DebugEndpointEnabled bool @@ -113,28 +145,36 @@ type Server struct { // When set, forwarding headers from these sources are preserved and // appended to instead of being stripped. TrustedProxies []netip.Prefix - // WireguardPort is the port for the WireGuard interface. Use 0 for a - // random OS-assigned port. A fixed port only works with single-account - // deployments; multiple accounts will fail to bind the same port. - WireguardPort int + // WireguardPort is the port for the NetBird tunnel interface. Use 0 + // for a random OS-assigned port. A fixed port only works with + // single-account deployments; multiple accounts will fail to bind + // the same port. + WireguardPort uint16 // ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners. // When enabled, the real client IP is extracted from the PROXY header // sent by upstream L4 proxies that support PROXY protocol. ProxyProtocol bool // PreSharedKey used for tunnel between proxy and peers (set globally not per account) PreSharedKey string + // SupportsCustomPorts indicates whether the proxy can bind arbitrary + // ports for TCP/UDP/TLS services. + SupportsCustomPorts bool + // DefaultDialTimeout is the default timeout for establishing backend + // connections when no per-service timeout is configured. Zero means + // each transport uses its own hardcoded default (typically 30s). + DefaultDialTimeout time.Duration } -// NotifyStatus sends a status update to management about tunnel connectivity -func (s *Server) NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error { +// NotifyStatus sends a status update to management about tunnel connectivity. +func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error { status := proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED if connected { status = proto.ProxyStatus_PROXY_STATUS_ACTIVE } _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ - ServiceId: serviceID, - AccountId: accountID, + ServiceId: string(serviceID), + AccountId: string(accountID), Status: status, CertificateIssued: false, }) @@ -142,10 +182,10 @@ func (s *Server) NotifyStatus(ctx context.Context, accountID, serviceID, domain } // NotifyCertificateIssued sends a notification to management that a certificate was issued -func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error { +func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error { _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ - ServiceId: serviceID, - AccountId: accountID, + ServiceId: string(serviceID), + AccountId: string(accountID), Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE, CertificateIssued: true, }) @@ -154,6 +194,11 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.initDefaults() + s.routerReady = make(chan struct{}) + s.udpRelays = make(map[types.ServiceID]*udprelay.Relay) + s.portRouters = make(map[uint16]*portRouter) + s.svcPorts = make(map[types.ServiceID][]uint16) + s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping) exporter, err := prometheus.New() if err != nil { @@ -179,7 +224,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { } }() s.mgmtClient = proto.NewProxyServiceClient(mgmtConn) - go s.newManagementMappingWorker(ctx, s.mgmtClient) + runCtx, runCancel := context.WithCancel(ctx) + defer runCancel() + go s.newManagementMappingWorker(runCtx, s.mgmtClient) // Initialize the netbird client, this is required to build peer connections // to proxy over. @@ -201,7 +248,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient) // Configure Access logs to management server. - accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) + s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) s.healthChecker = health.NewChecker(s.Logger, s.netbird) @@ -215,18 +262,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { handler := http.Handler(s.proxy) handler = s.auth.Protect(handler) handler = web.AssetHandler(handler) - handler = accessLog.Middleware(handler) + handler = s.accessLog.Middleware(handler) handler = s.meter.Middleware(handler) handler = s.hijackTracker.Middleware(handler) - // Start the reverse proxy HTTPS server. - s.https = &http.Server{ - Addr: addr, - Handler: handler, - TLSConfig: tlsConfig, - ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), - } - + // Start a raw TCP listener; the SNI router peeks at ClientHello + // and routes to either the HTTP handler or a TCP relay. lc := net.ListenConfig{} ln, err := lc.Listen(ctx, "tcp", addr) if err != nil { @@ -235,11 +276,34 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { if s.ProxyProtocol { ln = s.wrapProxyProtocol(ln) } + s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid + + // Set up the SNI router for TCP/HTTP multiplexing on the main port. + s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr()) + s.mainRouter.SetObserver(s.meter) + s.mainRouter.SetAccessLogger(s.accessLog) + close(s.routerReady) + + // The HTTP server uses the chanListener fed by the SNI router. + s.https = &http.Server{ + Addr: addr, + Handler: handler, + TLSConfig: tlsConfig, + ReadHeaderTimeout: httpReadHeaderTimeout, + IdleTimeout: httpIdleTimeout, + ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), + } httpsErr := make(chan error, 1) go func() { - s.Logger.Debugf("starting reverse proxy server on %s", addr) - httpsErr <- s.https.ServeTLS(ln, "", "") + s.Logger.Debug("starting HTTPS server on SNI router HTTP channel") + httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "") + }() + + routerErr := make(chan error, 1) + go func() { + s.Logger.Debugf("starting SNI router on %s", addr) + routerErr <- s.mainRouter.Serve(runCtx, ln) }() select { @@ -249,6 +313,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { return fmt.Errorf("https server: %w", err) } return nil + case err := <-routerErr: + s.shutdownServices() + if err != nil { + return fmt.Errorf("SNI router: %w", err) + } + return nil case <-ctx.Done(): s.gracefulShutdown() return nil @@ -376,6 +446,13 @@ const ( // shutdownServiceTimeout is the maximum time to wait for auxiliary // services (health probe, debug endpoint, ACME) to shut down. shutdownServiceTimeout = 5 * time.Second + + // httpReadHeaderTimeout limits how long the server waits to read + // request headers after accepting a connection. Prevents slowloris. + httpReadHeaderTimeout = 10 * time.Second + // httpIdleTimeout limits how long an idle keep-alive connection + // stays open before the server closes it. + httpIdleTimeout = 120 * time.Second ) func (s *Server) dialManagement() (*grpc.ClientConn, error) { @@ -437,7 +514,20 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) { "acme_server": s.ACMEDirectory, "challenge_type": s.ACMEChallengeType, }).Debug("ACME certificates enabled, configuring certificate manager") - s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod, s.meter) + var err error + s.acme, err = acme.NewManager(acme.ManagerConfig{ + CertDir: s.CertificateDirectory, + ACMEURL: s.ACMEDirectory, + EABKID: s.ACMEEABKID, + EABHMACKey: s.ACMEEABHMACKey, + LockMethod: s.CertLockMethod, + WildcardDir: s.WildcardCertDir, + }, s, s.Logger, s.meter) + if err != nil { + return nil, fmt.Errorf("create ACME manager: %w", err) + } + + go s.acme.WatchWildcards(ctx) if s.ACMEChallengeType == "http-01" { s.http = &http.Server{ @@ -453,6 +543,10 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) { } tlsConfig = s.acme.TLSConfig() + // autocert.Manager.TLSConfig() wires its own GetCertificate, which + // bypasses our override that checks wildcards first. + tlsConfig.GetCertificate = s.acme.GetCertificate + // ServerName needs to be set to allow for ACME to work correctly // when using CNAME URLs to access the proxy. tlsConfig.ServerName = s.ProxyURL @@ -496,6 +590,9 @@ func (s *Server) gracefulShutdown() { s.Logger.Infof("closed %d hijacked connection(s)", n) } + // Drain all router relay connections (main + per-port) in parallel. + s.drainAllRouters(shutdownDrainTimeout) + // Step 5: Stop all remaining background services. s.shutdownServices() s.Logger.Info("graceful shutdown complete") @@ -503,6 +600,34 @@ func (s *Server) gracefulShutdown() { // shutdownServices stops all background services concurrently and waits for // them to finish. +// drainAllRouters drains active relay connections on the main router and +// all per-port routers in parallel, up to the given timeout. +func (s *Server) drainAllRouters(timeout time.Duration) { + var wg sync.WaitGroup + + drain := func(name string, router *nbtcp.Router) { + wg.Add(1) + go func() { + defer wg.Done() + if ok := router.Drain(timeout); !ok { + s.Logger.Warnf("timed out draining %s relay connections", name) + } + }() + } + + if s.mainRouter != nil { + drain("main router", s.mainRouter) + } + + s.portMu.RLock() + for port, pr := range s.portRouters { + drain(fmt.Sprintf("port %d", port), pr.router) + } + s.portMu.RUnlock() + + wg.Wait() +} + func (s *Server) shutdownServices() { var wg sync.WaitGroup @@ -540,9 +665,165 @@ func (s *Server) shutdownServices() { }() } + // Close all UDP relays and wait for their goroutines to exit. + s.udpMu.Lock() + for id, relay := range s.udpRelays { + relay.Close() + delete(s.udpRelays, id) + } + s.udpMu.Unlock() + s.udpRelayWg.Wait() + + // Close all per-port routers. + s.portMu.Lock() + for port, pr := range s.portRouters { + pr.cancel() + if err := pr.listener.Close(); err != nil { + s.Logger.Debugf("close listener on port %d: %v", port, err) + } + delete(s.portRouters, port) + } + maps.Clear(s.svcPorts) + maps.Clear(s.lastMappings) + s.portMu.Unlock() + + // Wait for per-port router serve goroutines to exit. + s.portRouterWg.Wait() + wg.Wait() } +// resolveDialFunc returns a DialContextFunc that dials through the +// NetBird tunnel for the given account. +func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFunc, error) { + client, ok := s.netbird.GetClient(accountID) + if !ok { + return nil, fmt.Errorf("no client for account %s", accountID) + } + return client.DialContext, nil +} + +// notifyError reports a resource error back to management so it can be +// surfaced to the user (e.g. port bind failure, dialer resolution error). +func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) { + s.sendStatusUpdate(ctx, types.AccountID(mapping.GetAccountId()), types.ServiceID(mapping.GetId()), proto.ProxyStatus_PROXY_STATUS_ERROR, err) +} + +// sendStatusUpdate sends a status update for a service to management. +func (s *Server) sendStatusUpdate(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, st proto.ProxyStatus, err error) { + req := &proto.SendStatusUpdateRequest{ + ServiceId: string(serviceID), + AccountId: string(accountID), + Status: st, + } + if err != nil { + msg := err.Error() + req.ErrorMessage = &msg + } + if _, sendErr := s.mgmtClient.SendStatusUpdate(ctx, req); sendErr != nil { + s.Logger.Debugf("failed to send status update for %s: %v", serviceID, sendErr) + } +} + +// routerForPort returns the router that handles the given listen port. If port +// is 0 or matches the main listener port, the main router is returned. +// Otherwise a new per-port router is created and started. +func (s *Server) routerForPort(ctx context.Context, port uint16) (*nbtcp.Router, error) { + if port == 0 || port == s.mainPort { + return s.mainRouter, nil + } + return s.getOrCreatePortRouter(ctx, port) +} + +// routerForPortExisting returns the router for the given port without creating +// one. Returns the main router for port 0 / mainPort, or nil if no per-port +// router exists. +func (s *Server) routerForPortExisting(port uint16) *nbtcp.Router { + if port == 0 || port == s.mainPort { + return s.mainRouter + } + s.portMu.RLock() + pr := s.portRouters[port] + s.portMu.RUnlock() + if pr != nil { + return pr.router + } + return nil +} + +// getOrCreatePortRouter returns an existing per-port router or creates one +// with a new TCP listener and starts serving. +func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp.Router, error) { + s.portMu.Lock() + defer s.portMu.Unlock() + + if pr, ok := s.portRouters[port]; ok { + return pr.router, nil + } + + listenAddr := fmt.Sprintf(":%d", port) + ln, err := net.Listen("tcp", listenAddr) + if err != nil { + return nil, fmt.Errorf("listen TCP on %s: %w", listenAddr, err) + } + if s.ProxyProtocol { + ln = s.wrapProxyProtocol(ln) + } + + router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc) + router.SetObserver(s.meter) + router.SetAccessLogger(s.accessLog) + portCtx, cancel := context.WithCancel(ctx) + + s.portRouters[port] = &portRouter{ + router: router, + listener: ln, + cancel: cancel, + } + + s.portRouterWg.Add(1) + go func() { + defer s.portRouterWg.Done() + if err := router.Serve(portCtx, ln); err != nil { + s.Logger.Debugf("port %d router stopped: %v", port, err) + } + }() + + s.Logger.Debugf("started per-port router on %s", listenAddr) + return router, nil +} + +// cleanupPortIfEmpty tears down a per-port router if it has no remaining +// routes or fallback. The main port is never cleaned up. Active relay +// connections are drained before the listener is closed. +func (s *Server) cleanupPortIfEmpty(port uint16) { + if port == 0 || port == s.mainPort { + return + } + + s.portMu.Lock() + pr, ok := s.portRouters[port] + if !ok || !pr.router.IsEmpty() { + s.portMu.Unlock() + return + } + + // Cancel and close the listener while holding the lock so that + // getOrCreatePortRouter sees the entry is gone before we drain. + pr.cancel() + if err := pr.listener.Close(); err != nil { + s.Logger.Debugf("close listener on port %d: %v", port, err) + } + delete(s.portRouters, port) + s.portMu.Unlock() + + // Drain active relay connections outside the lock. + if ok := pr.router.Drain(nbtcp.DefaultDrainTimeout); !ok { + s.Logger.Warnf("timed out draining relay connections on port %d", port) + } + s.Logger.Debugf("cleaned up empty per-port router on port %d", port) +} + func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) { bo := &backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, @@ -568,6 +849,9 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr Version: s.Version, StartedAt: timestamppb.New(s.startTime), Address: s.ProxyURL, + Capabilities: &proto.ProxyCapabilities{ + SupportsCustomPorts: &s.SupportsCustomPorts, + }, }) if err != nil { return fmt.Errorf("create mapping stream: %w", err) @@ -604,6 +888,12 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr } func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool) error { + select { + case <-s.routerReady: + case <-ctx.Done(): + return ctx.Err() + } + for { // Check for context completion to gracefully shutdown. select { @@ -640,25 +930,28 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap s.Logger.WithFields(log.Fields{ "type": mapping.GetType(), "domain": mapping.GetDomain(), - "path": mapping.GetPath(), + "mode": mapping.GetMode(), + "port": mapping.GetListenPort(), "id": mapping.GetId(), }).Debug("Processing mapping update") switch mapping.GetType() { case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED: if err := s.addMapping(ctx, mapping); err != nil { - // TODO: Retry this? Or maybe notify the management server that this mapping has failed? s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "domain": mapping.GetDomain(), "error": err, }).Error("Error adding new mapping, ignoring this mapping and continuing processing") + s.notifyError(ctx, mapping, err) } case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED: - if err := s.updateMapping(ctx, mapping); err != nil { + if err := s.modifyMapping(ctx, mapping); err != nil { s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "domain": mapping.GetDomain(), - }).Errorf("failed to update mapping: %v", err) + "error": err, + }).Error("failed to modify mapping") + s.notifyError(ctx, mapping, err) } case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED: s.removeMapping(ctx, mapping) @@ -666,26 +959,331 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap } } +// addMapping registers a service mapping and starts the appropriate relay or routes. func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error { - d := domain.Domain(mapping.GetDomain()) accountID := types.AccountID(mapping.GetAccountId()) - serviceID := mapping.GetId() + svcID := types.ServiceID(mapping.GetId()) authToken := mapping.GetAuthToken() - if err := s.netbird.AddPeer(ctx, accountID, d, authToken, serviceID); err != nil { - return fmt.Errorf("create peer for domain %q: %w", d, err) - } - if s.acme != nil { - s.acme.AddDomain(d, string(accountID), serviceID) + svcKey := s.serviceKeyForMapping(mapping) + if err := s.netbird.AddPeer(ctx, accountID, svcKey, authToken, svcID); err != nil { + return fmt.Errorf("create peer for service %s: %w", svcID, err) } - // Pass the mapping through to the update function to avoid duplicating the - // setup, currently update is simply a subset of this function, so this - // separation makes sense...to me at least. + if err := s.setupMappingRoutes(ctx, mapping); err != nil { + s.cleanupMappingRoutes(mapping) + if peerErr := s.netbird.RemovePeer(ctx, accountID, svcKey); peerErr != nil { + s.Logger.WithError(peerErr).WithField("service_id", svcID).Warn("failed to remove peer after setup failure") + } + return err + } + s.storeMapping(mapping) + return nil +} + +// modifyMapping updates a service mapping in place without tearing down the +// NetBird peer. It cleans up old routes using the previously stored mapping +// state and re-applies them from the new mapping. +func (s *Server) modifyMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + if old := s.loadMapping(types.ServiceID(mapping.GetId())); old != nil { + s.cleanupMappingRoutes(old) + if mode := types.ServiceMode(old.GetMode()); mode.IsL4() { + s.meter.L4ServiceRemoved(mode) + } + } else { + s.cleanupMappingRoutes(mapping) + } + if err := s.setupMappingRoutes(ctx, mapping); err != nil { + s.cleanupMappingRoutes(mapping) + return err + } + s.storeMapping(mapping) + return nil +} + +// setupMappingRoutes configures the appropriate routes or relays for the given +// service mapping based on its mode. The NetBird peer must already exist. +func (s *Server) setupMappingRoutes(ctx context.Context, mapping *proto.ProxyMapping) error { + switch types.ServiceMode(mapping.GetMode()) { + case types.ServiceModeTCP: + return s.setupTCPMapping(ctx, mapping) + case types.ServiceModeUDP: + return s.setupUDPMapping(ctx, mapping) + case types.ServiceModeTLS: + return s.setupTLSMapping(ctx, mapping) + default: + return s.setupHTTPMapping(ctx, mapping) + } +} + +// setupHTTPMapping configures HTTP reverse proxy, auth, and ACME routes. +func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + d := domain.Domain(mapping.GetDomain()) + accountID := types.AccountID(mapping.GetAccountId()) + svcID := types.ServiceID(mapping.GetId()) + + if len(mapping.GetPath()) == 0 { + return nil + } + + var wildcardHit bool + if s.acme != nil { + wildcardHit = s.acme.AddDomain(d, accountID, svcID) + } + s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ + Type: nbtcp.RouteHTTP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + }) if err := s.updateMapping(ctx, mapping); err != nil { - s.removeMapping(ctx, mapping) return fmt.Errorf("update mapping for domain %q: %w", d, err) } + + if wildcardHit { + if err := s.NotifyCertificateIssued(ctx, accountID, svcID, string(d)); err != nil { + s.Logger.Warnf("notify certificate ready for domain %q: %v", d, err) + } + } + + return nil +} + +// setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port. +func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + port, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("TCP service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for TCP service %s", svcID) + } + + if s.WireguardPort != 0 && port == s.WireguardPort { + return fmt.Errorf("port %d conflicts with tunnel port", port) + } + + router, err := s.routerForPort(ctx, port) + if err != nil { + return fmt.Errorf("router for TCP port %d: %w", port, err) + } + + router.SetFallback(nbtcp.Route{ + Type: nbtcp.RouteTCP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + Protocol: accesslog.ProtocolTCP, + Target: targetAddr, + ProxyProtocol: s.l4ProxyProtocol(mapping), + DialTimeout: s.l4DialTimeout(mapping), + }) + + s.portMu.Lock() + s.svcPorts[svcID] = []uint16{port} + s.portMu.Unlock() + + s.meter.L4ServiceAdded(types.ServiceModeTCP) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// setupUDPMapping starts a UDP relay on the listen port. +func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + port, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("UDP service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for UDP service %s", svcID) + } + + if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil { + return fmt.Errorf("UDP relay for service %s: %w", svcID, err) + } + + s.meter.L4ServiceAdded(types.ServiceModeUDP) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// setupTLSMapping configures a TLS SNI-routed passthrough on the listen port. +func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + tlsPort, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("TLS service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for TLS service %s", svcID) + } + + if s.WireguardPort != 0 && tlsPort == s.WireguardPort { + return fmt.Errorf("port %d conflicts with tunnel port", tlsPort) + } + + router, err := s.routerForPort(ctx, tlsPort) + if err != nil { + return fmt.Errorf("router for TLS port %d: %w", tlsPort, err) + } + + router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ + Type: nbtcp.RouteTCP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + Protocol: accesslog.ProtocolTLS, + Target: targetAddr, + ProxyProtocol: s.l4ProxyProtocol(mapping), + DialTimeout: s.l4DialTimeout(mapping), + }) + + if tlsPort != s.mainPort { + s.portMu.Lock() + s.svcPorts[svcID] = []uint16{tlsPort} + s.portMu.Unlock() + } + + s.Logger.WithFields(log.Fields{ + "domain": mapping.GetDomain(), + "target": targetAddr, + "port": tlsPort, + "service": svcID, + }).Info("TLS passthrough mapping added") + + s.meter.L4ServiceAdded(types.ServiceModeTLS) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// serviceKeyForMapping returns the appropriate ServiceKey for a mapping. +// TCP/UDP use an ID-based key; HTTP/TLS use a domain-based key. +func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.ServiceKey { + switch types.ServiceMode(mapping.GetMode()) { + case types.ServiceModeTCP, types.ServiceModeUDP: + return roundtrip.L4ServiceKey(types.ServiceID(mapping.GetId())) + default: + return roundtrip.DomainServiceKey(mapping.GetDomain()) + } +} + +// l4TargetAddress extracts and validates the target address from a mapping's +// first path entry. Returns empty string if no paths exist or the address is +// not a valid host:port. +func (s *Server) l4TargetAddress(mapping *proto.ProxyMapping) string { + paths := mapping.GetPath() + if len(paths) == 0 { + return "" + } + target := paths[0].GetTarget() + if _, _, err := net.SplitHostPort(target); err != nil { + s.Logger.WithFields(log.Fields{ + "service_id": mapping.GetId(), + "target": target, + }).Warnf("invalid L4 target address: %v", err) + return "" + } + return target +} + +// l4ProxyProtocol returns whether the first target has PROXY protocol enabled. +func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool { + paths := mapping.GetPath() + if len(paths) == 0 { + return false + } + return paths[0].GetOptions().GetProxyProtocol() +} + +// l4DialTimeout returns the dial timeout from the first target's options, +// falling back to the server's DefaultDialTimeout. +func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration { + paths := mapping.GetPath() + if len(paths) > 0 { + if d := paths[0].GetOptions().GetRequestTimeout(); d != nil { + return d.AsDuration() + } + } + return s.DefaultDialTimeout +} + +// l4SessionIdleTimeout returns the configured session idle timeout from the +// mapping options, or 0 to use the relay's default. +func l4SessionIdleTimeout(mapping *proto.ProxyMapping) time.Duration { + paths := mapping.GetPath() + if len(paths) > 0 { + if d := paths[0].GetOptions().GetSessionIdleTimeout(); d != nil { + return d.AsDuration() + } + } + return 0 +} + +// addUDPRelay starts a UDP relay on the specified listen port. +func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, targetAddress string, listenPort uint16) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + if s.WireguardPort != 0 && listenPort == s.WireguardPort { + return fmt.Errorf("UDP port %d conflicts with tunnel port", listenPort) + } + + // Close existing relay if present (idempotent re-add). + s.removeUDPRelay(svcID) + + listenAddr := fmt.Sprintf(":%d", listenPort) + + listener, err := net.ListenPacket("udp", listenAddr) + if err != nil { + return fmt.Errorf("listen UDP on %s: %w", listenAddr, err) + } + + dialFn, err := s.resolveDialFunc(accountID) + if err != nil { + _ = listener.Close() + return fmt.Errorf("resolve dialer for UDP: %w", err) + } + + entry := s.Logger.WithFields(log.Fields{ + "target": targetAddress, + "listen_port": listenPort, + "service_id": svcID, + }) + + relay := udprelay.New(ctx, udprelay.RelayConfig{ + Logger: entry, + Listener: listener, + Target: targetAddress, + Domain: mapping.GetDomain(), + AccountID: accountID, + ServiceID: svcID, + DialFunc: dialFn, + DialTimeout: s.l4DialTimeout(mapping), + SessionTTL: l4SessionIdleTimeout(mapping), + AccessLog: s.accessLog, + }) + relay.SetObserver(s.meter) + + s.udpMu.Lock() + s.udpRelays[svcID] = relay + s.udpMu.Unlock() + + s.udpRelayWg.Go(relay.Serve) + entry.Info("UDP relay added") return nil } @@ -695,50 +1293,142 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) // the auth and proxy mappings. // Note: this does require the management server to always send a // full mapping rather than deltas during a modification. + accountID := types.AccountID(mapping.GetAccountId()) + svcID := types.ServiceID(mapping.GetId()) + var schemes []auth.Scheme if mapping.GetAuth().GetPassword() { - schemes = append(schemes, auth.NewPassword(s.mgmtClient, mapping.GetId(), mapping.GetAccountId())) + schemes = append(schemes, auth.NewPassword(s.mgmtClient, svcID, accountID)) } if mapping.GetAuth().GetPin() { - schemes = append(schemes, auth.NewPin(s.mgmtClient, mapping.GetId(), mapping.GetAccountId())) + schemes = append(schemes, auth.NewPin(s.mgmtClient, svcID, accountID)) } if mapping.GetAuth().GetOidc() { - schemes = append(schemes, auth.NewOIDC(s.mgmtClient, mapping.GetId(), mapping.GetAccountId(), s.ForwardedProto)) + schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto)) } maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second - if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, mapping.GetAccountId(), mapping.GetId()); err != nil { + if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID); err != nil { return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err) } - s.proxy.AddMapping(s.protoToMapping(mapping)) - s.meter.AddMapping(s.protoToMapping(mapping)) + m := s.protoToMapping(ctx, mapping) + s.proxy.AddMapping(m) + s.meter.AddMapping(m) return nil } +// removeMapping tears down routes/relays and the NetBird peer for a service. +// Uses the stored mapping state when available to ensure all previously +// configured routes are cleaned up. func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) { - d := domain.Domain(mapping.GetDomain()) accountID := types.AccountID(mapping.GetAccountId()) - if err := s.netbird.RemovePeer(ctx, accountID, d); err != nil { + svcKey := s.serviceKeyForMapping(mapping) + if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil { s.Logger.WithFields(log.Fields{ "account_id": accountID, - "domain": d, + "service_id": mapping.GetId(), "error": err, - }).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist") + }).Error("failed to remove NetBird peer, continuing cleanup") } - if s.acme != nil { - s.acme.RemoveDomain(d) + + if old := s.deleteMapping(types.ServiceID(mapping.GetId())); old != nil { + s.cleanupMappingRoutes(old) + if mode := types.ServiceMode(old.GetMode()); mode.IsL4() { + s.meter.L4ServiceRemoved(mode) + } + } else { + s.cleanupMappingRoutes(mapping) } - s.auth.RemoveDomain(mapping.GetDomain()) - s.proxy.RemoveMapping(s.protoToMapping(mapping)) - s.meter.RemoveMapping(s.protoToMapping(mapping)) } -func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { +// cleanupMappingRoutes removes HTTP/TLS/L4 routes and custom port state for a +// service without touching the NetBird peer. This is used for both full +// removal and in-place modification of mappings. +func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) { + svcID := types.ServiceID(mapping.GetId()) + host := mapping.GetDomain() + + // HTTP/TLS cleanup (only relevant when a domain is set). + if host != "" { + d := domain.Domain(host) + if s.acme != nil { + s.acme.RemoveDomain(d) + } + s.auth.RemoveDomain(host) + if s.proxy.RemoveMapping(proxy.Mapping{Host: host}) { + s.meter.RemoveMapping(proxy.Mapping{Host: host}) + } + // Close hijacked connections (WebSocket) for this domain. + if n := s.hijackTracker.CloseByHost(host); n > 0 { + s.Logger.Debugf("closed %d hijacked connection(s) for %s", n, host) + } + // Remove SNI route from the main router (covers both HTTP and main-port TLS). + s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID) + } + + // Extract and delete tracked custom-port entries atomically. + s.portMu.Lock() + entries := s.svcPorts[svcID] + delete(s.svcPorts, svcID) + s.portMu.Unlock() + + for _, entry := range entries { + if router := s.routerForPortExisting(entry); router != nil { + if host != "" { + router.RemoveRoute(nbtcp.SNIHost(host), svcID) + } else { + router.RemoveFallback(svcID) + } + } + s.cleanupPortIfEmpty(entry) + } + + // UDP relay cleanup (idempotent). + s.removeUDPRelay(svcID) + +} + +// removeUDPRelay stops and removes a UDP relay by service ID. +func (s *Server) removeUDPRelay(svcID types.ServiceID) { + s.udpMu.Lock() + relay, ok := s.udpRelays[svcID] + if ok { + delete(s.udpRelays, svcID) + } + s.udpMu.Unlock() + + if ok { + relay.Close() + s.Logger.WithField("service_id", svcID).Info("UDP relay removed") + } +} + +func (s *Server) storeMapping(mapping *proto.ProxyMapping) { + s.portMu.Lock() + s.lastMappings[types.ServiceID(mapping.GetId())] = mapping + s.portMu.Unlock() +} + +func (s *Server) loadMapping(svcID types.ServiceID) *proto.ProxyMapping { + s.portMu.RLock() + m := s.lastMappings[svcID] + s.portMu.RUnlock() + return m +} + +func (s *Server) deleteMapping(svcID types.ServiceID) *proto.ProxyMapping { + s.portMu.Lock() + m := s.lastMappings[svcID] + delete(s.lastMappings, svcID) + s.portMu.Unlock() + return m +} + +func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping) proxy.Mapping { paths := make(map[string]*proxy.PathTarget) for _, pathMapping := range mapping.GetPath() { targetURL, err := url.Parse(pathMapping.GetTarget()) if err != nil { - // TODO: Should we warn management about this so it can be bubbled up to a user to reconfigure? s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "account_id": mapping.GetAccountId(), @@ -746,6 +1436,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { "path": pathMapping.GetPath(), "target": pathMapping.GetTarget(), }).WithError(err).Error("failed to parse target URL for path, skipping") + s.notifyError(ctx, mapping, fmt.Errorf("invalid target URL %q for path %q: %w", pathMapping.GetTarget(), pathMapping.GetPath(), err)) continue } @@ -758,10 +1449,13 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { pt.RequestTimeout = d.AsDuration() } } + if pt.RequestTimeout == 0 && s.DefaultDialTimeout > 0 { + pt.RequestTimeout = s.DefaultDialTimeout + } paths[pathMapping.GetPath()] = pt } return proxy.Mapping{ - ID: mapping.GetId(), + ID: types.ServiceID(mapping.GetId()), AccountID: types.AccountID(mapping.GetAccountId()), Host: mapping.GetDomain(), Paths: paths, diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 9505b3fdf..333f0bf00 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -56,12 +56,14 @@ type ExposeRequest struct { Pin string Password string UserGroups []string + ListenPort uint16 } type ExposeResponse struct { - ServiceName string - Domain string - ServiceURL string + ServiceName string + Domain string + ServiceURL string + PortAutoAssigned bool } // NewClient creates a new client to Management service @@ -790,9 +792,10 @@ func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error { func fromProtoExposeResponse(resp *proto.ExposeServiceResponse) *ExposeResponse { return &ExposeResponse{ - ServiceName: resp.ServiceName, - Domain: resp.Domain, - ServiceURL: resp.ServiceUrl, + ServiceName: resp.ServiceName, + Domain: resp.Domain, + ServiceURL: resp.ServiceUrl, + PortAutoAssigned: resp.PortAutoAssigned, } } @@ -808,6 +811,8 @@ func toProtoExposeServiceRequest(req ExposeRequest) (*proto.ExposeServiceRequest protocol = proto.ExposeProtocol_EXPOSE_TCP case int(proto.ExposeProtocol_EXPOSE_UDP): protocol = proto.ExposeProtocol_EXPOSE_UDP + case int(proto.ExposeProtocol_EXPOSE_TLS): + protocol = proto.ExposeProtocol_EXPOSE_TLS default: return nil, fmt.Errorf("invalid expose protocol: %d", req.Protocol) } @@ -820,6 +825,7 @@ func toProtoExposeServiceRequest(req ExposeRequest) (*proto.ExposeServiceRequest Pin: req.Pin, Password: req.Password, UserGroups: req.UserGroups, + ListenPort: uint32(req.ListenPort), }, nil } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 6d64c002b..bdeb5d1b4 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -350,6 +350,10 @@ components: description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") type: string example: "0.51.2" + auto_update_always: + description: When true, updates are installed automatically in the background. When false, updates require user interaction from the UI. + type: boolean + example: false embedded_idp_enabled: description: Indicates whether the embedded identity provider (Dex) is enabled for this account. This is a read-only field. type: boolean @@ -2835,6 +2839,10 @@ components: format: int64 description: "Bytes downloaded (response body size)" example: 8192 + protocol: + type: string + description: "Protocol type: http, tcp, or udp" + example: "http" required: - id - service_id @@ -2953,6 +2961,20 @@ components: domain: type: string description: Domain for the service + mode: + type: string + description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + enum: [http, tcp, udp, tls] + default: http + listen_port: + type: integer + minimum: 0 + maximum: 65535 + description: Port the proxy listens on (L4/TLS only) + port_auto_assigned: + type: boolean + description: Whether the listen port was auto-assigned + readOnly: true proxy_cluster: type: string description: The proxy cluster handling this service (derived from domain) @@ -3019,6 +3041,16 @@ components: domain: type: string description: Domain for the service + mode: + type: string + description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + enum: [http, tcp, udp, tls] + default: http + listen_port: + type: integer + minimum: 0 + maximum: 65535 + description: Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment. targets: type: array items: @@ -3039,8 +3071,6 @@ components: required: - name - domain - - targets - - auth - enabled ServiceTargetOptions: type: object @@ -3064,6 +3094,12 @@ components: additionalProperties: type: string pattern: '^[^\r\n]*$' + proxy_protocol: + type: boolean + description: Send PROXY Protocol v2 header to this backend (TCP/TLS only) + session_idle_timeout: + type: string + description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m. ServiceTarget: type: object properties: @@ -3072,21 +3108,23 @@ components: description: Target ID target_type: type: string - description: Target type (e.g., "peer", "resource") - enum: [peer, resource] + description: Target type + enum: [peer, host, domain, subnet] path: type: string - description: URL path prefix for this target + description: URL path prefix for this target (HTTP only) protocol: type: string description: Protocol to use when connecting to the backend - enum: [http, https] + enum: [http, https, tcp, udp] host: type: string description: Backend ip or domain for this target port: type: integer - description: Backend port for this target. Use 0 or omit to use the scheme default (80 for http, 443 for https). + minimum: 1 + maximum: 65535 + description: Backend port for this target enabled: type: boolean description: Whether this target is enabled @@ -3193,6 +3231,9 @@ components: target_cluster: type: string description: The proxy cluster this domain is validated against (only for custom domains) + supports_custom_ports: + type: boolean + description: Whether the cluster supports binding arbitrary TCP/UDP ports required: - id - domain @@ -4405,6 +4446,12 @@ components: requires_authentication: description: Requires authentication content: { } + conflict: + description: Conflict + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' securitySchemes: BearerAuth: type: http @@ -9904,6 +9951,29 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorResponse' + /api/reverse-proxies/clusters: + get: + summary: List available proxy clusters + description: Returns a list of available proxy clusters with their connection status + tags: [ Services ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of proxy clusters + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProxyCluster' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services: get: summary: List all Services @@ -9953,29 +10023,8 @@ paths: "$ref": "#/components/responses/requires_authentication" '403': "$ref": "#/components/responses/forbidden" - '500': - "$ref": "#/components/responses/internal_error" - /api/reverse-proxies/clusters: - get: - summary: List available proxy clusters - description: Returns a list of available proxy clusters with their connection status - tags: [ Services ] - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - responses: - '200': - description: A JSON Array of proxy clusters - content: - application/json: - schema: - type: array - items: - $ref: '#/components/schemas/ProxyCluster' - '401': - "$ref": "#/components/responses/requires_authentication" - '403': - "$ref": "#/components/responses/forbidden" + '409': + "$ref": "#/components/responses/conflict" '500': "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services/{serviceId}: @@ -10045,6 +10094,8 @@ paths: "$ref": "#/components/responses/forbidden" '404': "$ref": "#/components/responses/not_found" + '409': + "$ref": "#/components/responses/conflict" '500': "$ref": "#/components/responses/internal_error" delete: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index d485d765c..4dee56c07 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -880,6 +880,30 @@ func (e SentinelOneMatchAttributesNetworkStatus) Valid() bool { } } +// Defines values for ServiceMode. +const ( + ServiceModeHttp ServiceMode = "http" + ServiceModeTcp ServiceMode = "tcp" + ServiceModeTls ServiceMode = "tls" + ServiceModeUdp ServiceMode = "udp" +) + +// Valid indicates whether the value is a known member of the ServiceMode enum. +func (e ServiceMode) Valid() bool { + switch e { + case ServiceModeHttp: + return true + case ServiceModeTcp: + return true + case ServiceModeTls: + return true + case ServiceModeUdp: + return true + default: + return false + } +} + // Defines values for ServiceMetaStatus. const ( ServiceMetaStatusActive ServiceMetaStatus = "active" @@ -910,10 +934,36 @@ func (e ServiceMetaStatus) Valid() bool { } } +// Defines values for ServiceRequestMode. +const ( + ServiceRequestModeHttp ServiceRequestMode = "http" + ServiceRequestModeTcp ServiceRequestMode = "tcp" + ServiceRequestModeTls ServiceRequestMode = "tls" + ServiceRequestModeUdp ServiceRequestMode = "udp" +) + +// Valid indicates whether the value is a known member of the ServiceRequestMode enum. +func (e ServiceRequestMode) Valid() bool { + switch e { + case ServiceRequestModeHttp: + return true + case ServiceRequestModeTcp: + return true + case ServiceRequestModeTls: + return true + case ServiceRequestModeUdp: + return true + default: + return false + } +} + // Defines values for ServiceTargetProtocol. const ( ServiceTargetProtocolHttp ServiceTargetProtocol = "http" ServiceTargetProtocolHttps ServiceTargetProtocol = "https" + ServiceTargetProtocolTcp ServiceTargetProtocol = "tcp" + ServiceTargetProtocolUdp ServiceTargetProtocol = "udp" ) // Valid indicates whether the value is a known member of the ServiceTargetProtocol enum. @@ -923,6 +973,10 @@ func (e ServiceTargetProtocol) Valid() bool { return true case ServiceTargetProtocolHttps: return true + case ServiceTargetProtocolTcp: + return true + case ServiceTargetProtocolUdp: + return true default: return false } @@ -930,16 +984,22 @@ func (e ServiceTargetProtocol) Valid() bool { // Defines values for ServiceTargetTargetType. const ( - ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" - ServiceTargetTargetTypeResource ServiceTargetTargetType = "resource" + ServiceTargetTargetTypeDomain ServiceTargetTargetType = "domain" + ServiceTargetTargetTypeHost ServiceTargetTargetType = "host" + ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" + ServiceTargetTargetTypeSubnet ServiceTargetTargetType = "subnet" ) // Valid indicates whether the value is a known member of the ServiceTargetTargetType enum. func (e ServiceTargetTargetType) Valid() bool { switch e { + case ServiceTargetTargetTypeDomain: + return true + case ServiceTargetTargetTypeHost: + return true case ServiceTargetTargetTypePeer: return true - case ServiceTargetTargetTypeResource: + case ServiceTargetTargetTypeSubnet: return true default: return false @@ -1307,6 +1367,9 @@ type AccountRequest struct { // AccountSettings defines model for AccountSettings. type AccountSettings struct { + // AutoUpdateAlways When true, updates are installed automatically in the background. When false, updates require user interaction from the UI. + AutoUpdateAlways *bool `json:"auto_update_always,omitempty"` + // AutoUpdateVersion Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") AutoUpdateVersion *string `json:"auto_update_version,omitempty"` @@ -3327,6 +3390,9 @@ type ProxyAccessLog struct { // Path Path of the request Path string `json:"path"` + // Protocol Protocol type: http, tcp, or udp + Protocol *string `json:"protocol,omitempty"` + // Reason Reason for the request result (e.g., authentication failure) Reason *string `json:"reason,omitempty"` @@ -3391,6 +3457,9 @@ type ReverseProxyDomain struct { // Id Domain ID Id string `json:"id"` + // SupportsCustomPorts Whether the cluster supports binding arbitrary TCP/UDP ports + SupportsCustomPorts *bool `json:"supports_custom_ports,omitempty"` + // TargetCluster The proxy cluster this domain is validated against (only for custom domains) TargetCluster *string `json:"target_cluster,omitempty"` @@ -3583,8 +3652,14 @@ type Service struct { Enabled bool `json:"enabled"` // Id Service ID - Id string `json:"id"` - Meta ServiceMeta `json:"meta"` + Id string `json:"id"` + + // ListenPort Port the proxy listens on (L4/TLS only) + ListenPort *int `json:"listen_port,omitempty"` + Meta ServiceMeta `json:"meta"` + + // Mode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + Mode *ServiceMode `json:"mode,omitempty"` // Name Service name Name string `json:"name"` @@ -3592,6 +3667,9 @@ type Service struct { // PassHostHeader When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address PassHostHeader *bool `json:"pass_host_header,omitempty"` + // PortAutoAssigned Whether the listen port was auto-assigned + PortAutoAssigned *bool `json:"port_auto_assigned,omitempty"` + // ProxyCluster The proxy cluster handling this service (derived from domain) ProxyCluster *string `json:"proxy_cluster,omitempty"` @@ -3602,6 +3680,9 @@ type Service struct { Targets []ServiceTarget `json:"targets"` } +// ServiceMode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. +type ServiceMode string + // ServiceAuthConfig defines model for ServiceAuthConfig. type ServiceAuthConfig struct { BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty"` @@ -3627,7 +3708,7 @@ type ServiceMetaStatus string // ServiceRequest defines model for ServiceRequest. type ServiceRequest struct { - Auth ServiceAuthConfig `json:"auth"` + Auth *ServiceAuthConfig `json:"auth,omitempty"` // Domain Domain for the service Domain string `json:"domain"` @@ -3635,6 +3716,12 @@ type ServiceRequest struct { // Enabled Whether the service is enabled Enabled bool `json:"enabled"` + // ListenPort Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment. + ListenPort *int `json:"listen_port,omitempty"` + + // Mode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + Mode *ServiceRequestMode `json:"mode,omitempty"` + // Name Service name Name string `json:"name"` @@ -3645,9 +3732,12 @@ type ServiceRequest struct { RewriteRedirects *bool `json:"rewrite_redirects,omitempty"` // Targets List of target backends for this service - Targets []ServiceTarget `json:"targets"` + Targets *[]ServiceTarget `json:"targets,omitempty"` } +// ServiceRequestMode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. +type ServiceRequestMode string + // ServiceTarget defines model for ServiceTarget. type ServiceTarget struct { // Enabled Whether this target is enabled @@ -3657,10 +3747,10 @@ type ServiceTarget struct { Host *string `json:"host,omitempty"` Options *ServiceTargetOptions `json:"options,omitempty"` - // Path URL path prefix for this target + // Path URL path prefix for this target (HTTP only) Path *string `json:"path,omitempty"` - // Port Backend port for this target. Use 0 or omit to use the scheme default (80 for http, 443 for https). + // Port Backend port for this target Port int `json:"port"` // Protocol Protocol to use when connecting to the backend @@ -3669,14 +3759,14 @@ type ServiceTarget struct { // TargetId Target ID TargetId string `json:"target_id"` - // TargetType Target type (e.g., "peer", "resource") + // TargetType Target type TargetType ServiceTargetTargetType `json:"target_type"` } // ServiceTargetProtocol Protocol to use when connecting to the backend type ServiceTargetProtocol string -// ServiceTargetTargetType Target type (e.g., "peer", "resource") +// ServiceTargetTargetType Target type type ServiceTargetTargetType string // ServiceTargetOptions defines model for ServiceTargetOptions. @@ -3687,9 +3777,15 @@ type ServiceTargetOptions struct { // PathRewrite Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path. PathRewrite *ServiceTargetOptionsPathRewrite `json:"path_rewrite,omitempty"` + // ProxyProtocol Send PROXY Protocol v2 header to this backend (TCP/TLS only) + ProxyProtocol *bool `json:"proxy_protocol,omitempty"` + // RequestTimeout Per-target response timeout as a Go duration string (e.g. "30s", "2m") RequestTimeout *string `json:"request_timeout,omitempty"` + // SessionIdleTimeout Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m. + SessionIdleTimeout *string `json:"session_idle_timeout,omitempty"` + // SkipTlsVerify Skip TLS certificate verification for this backend SkipTlsVerify *bool `json:"skip_tls_verify,omitempty"` } @@ -4214,6 +4310,9 @@ type ZoneRequest struct { Name string `json:"name"` } +// Conflict Standard error response. Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided. +type Conflict = ErrorResponse + // GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. type GetApiEventsNetworkTrafficParams struct { // Page Page number diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 2c66bb946..c5581296c 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -228,6 +228,7 @@ const ( ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1 ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2 ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3 + ExposeProtocol_EXPOSE_TLS ExposeProtocol = 4 ) // Enum value maps for ExposeProtocol. @@ -237,12 +238,14 @@ var ( 1: "EXPOSE_HTTPS", 2: "EXPOSE_TCP", 3: "EXPOSE_UDP", + 4: "EXPOSE_TLS", } ExposeProtocol_value = map[string]int32{ "EXPOSE_HTTP": 0, "EXPOSE_HTTPS": 1, "EXPOSE_TCP": 2, "EXPOSE_UDP": 3, + "EXPOSE_TLS": 4, } ) @@ -4047,6 +4050,7 @@ type ExposeServiceRequest struct { UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"` Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"` NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"` + ListenPort uint32 `protobuf:"varint,8,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` } func (x *ExposeServiceRequest) Reset() { @@ -4130,14 +4134,22 @@ func (x *ExposeServiceRequest) GetNamePrefix() string { return "" } +func (x *ExposeServiceRequest) GetListenPort() uint32 { + if x != nil { + return x.ListenPort + } + return 0 +} + type ExposeServiceResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` - ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` - Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` + ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` + Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + PortAutoAssigned bool `protobuf:"varint,4,opt,name=port_auto_assigned,json=portAutoAssigned,proto3" json:"port_auto_assigned,omitempty"` } func (x *ExposeServiceResponse) Reset() { @@ -4193,6 +4205,13 @@ func (x *ExposeServiceResponse) GetDomain() string { return "" } +func (x *ExposeServiceResponse) GetPortAutoAssigned() bool { + if x != nil { + return x.PortAutoAssigned + } + return false +} + type RenewExposeRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -4996,7 +5015,7 @@ var file_management_proto_rawDesc = []byte{ 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, - 0x74, 0x22, 0xea, 0x01, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x74, 0x22, 0x8b, 0x02, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x36, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, @@ -5010,15 +5029,20 @@ var file_management_proto_rawDesc = []byte{ 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x22, 0x73, - 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, + 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x1f, + 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x22, + 0xa1, 0x01, 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, 0x6c, 0x12, 0x16, 0x0a, + 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2c, 0x0a, 0x12, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x61, 0x75, + 0x74, 0x6f, 0x5f, 0x61, 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x10, 0x70, 0x6f, 0x72, 0x74, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x73, 0x73, 0x69, 0x67, + 0x6e, 0x65, 0x64, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x15, 0x0a, 0x13, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, @@ -5039,12 +5063,13 @@ var file_management_proto_rawDesc = []byte{ 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, - 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, 0x53, 0x0a, 0x0e, 0x45, + 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, 0x63, 0x0a, 0x0e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0f, 0x0a, 0x0b, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x55, 0x44, 0x50, 0x10, 0x03, + 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x32, 0xfd, 0x06, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 3667ae27f..9acf7e2b3 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -340,8 +340,8 @@ message PeerConfig { message AutoUpdateSettings { string version = 1; /* - alwaysUpdate = true → Updates happen automatically in the background - alwaysUpdate = false → Updates only happen when triggered by a peer connection + alwaysUpdate = true → Updates are installed automatically in the background + alwaysUpdate = false → Updates require user interaction from the UI */ bool alwaysUpdate = 2; } @@ -652,6 +652,7 @@ enum ExposeProtocol { EXPOSE_HTTPS = 1; EXPOSE_TCP = 2; EXPOSE_UDP = 3; + EXPOSE_TLS = 4; } message ExposeServiceRequest { @@ -662,12 +663,14 @@ message ExposeServiceRequest { repeated string user_groups = 5; string domain = 6; string name_prefix = 7; + uint32 listen_port = 8; } message ExposeServiceResponse { string service_name = 1; string service_url = 2; string domain = 3; + bool port_auto_assigned = 4; } message RenewExposeRequest { diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index 275e8be37..115ac5101 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.0 +// protoc v6.33.3 // source: proxy_service.proto package proto @@ -175,22 +175,72 @@ func (ProxyStatus) EnumDescriptor() ([]byte, []int) { return file_proxy_service_proto_rawDescGZIP(), []int{2} } +// ProxyCapabilities describes what a proxy can handle. +type ProxyCapabilities struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. + SupportsCustomPorts *bool `protobuf:"varint,1,opt,name=supports_custom_ports,json=supportsCustomPorts,proto3,oneof" json:"supports_custom_ports,omitempty"` +} + +func (x *ProxyCapabilities) Reset() { + *x = ProxyCapabilities{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ProxyCapabilities) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ProxyCapabilities) ProtoMessage() {} + +func (x *ProxyCapabilities) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ProxyCapabilities.ProtoReflect.Descriptor instead. +func (*ProxyCapabilities) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{0} +} + +func (x *ProxyCapabilities) GetSupportsCustomPorts() bool { + if x != nil && x.SupportsCustomPorts != nil { + return *x.SupportsCustomPorts + } + return false +} + // GetMappingUpdateRequest is sent to initialise a mapping stream. type GetMappingUpdateRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` - Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` - StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` - Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` + Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` + StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` + Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + Capabilities *ProxyCapabilities `protobuf:"bytes,5,opt,name=capabilities,proto3" json:"capabilities,omitempty"` } func (x *GetMappingUpdateRequest) Reset() { *x = GetMappingUpdateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[0] + mi := &file_proxy_service_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -203,7 +253,7 @@ func (x *GetMappingUpdateRequest) String() string { func (*GetMappingUpdateRequest) ProtoMessage() {} func (x *GetMappingUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[0] + mi := &file_proxy_service_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -216,7 +266,7 @@ func (x *GetMappingUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMappingUpdateRequest.ProtoReflect.Descriptor instead. func (*GetMappingUpdateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{0} + return file_proxy_service_proto_rawDescGZIP(), []int{1} } func (x *GetMappingUpdateRequest) GetProxyId() string { @@ -247,6 +297,13 @@ func (x *GetMappingUpdateRequest) GetAddress() string { return "" } +func (x *GetMappingUpdateRequest) GetCapabilities() *ProxyCapabilities { + if x != nil { + return x.Capabilities + } + return nil +} + // GetMappingUpdateResponse contains zero or more ProxyMappings. // No mappings may be sent to test the liveness of the Proxy. // Mappings that are sent should be interpreted by the Proxy appropriately. @@ -264,7 +321,7 @@ type GetMappingUpdateResponse struct { func (x *GetMappingUpdateResponse) Reset() { *x = GetMappingUpdateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[1] + mi := &file_proxy_service_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -277,7 +334,7 @@ func (x *GetMappingUpdateResponse) String() string { func (*GetMappingUpdateResponse) ProtoMessage() {} func (x *GetMappingUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[1] + mi := &file_proxy_service_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -290,7 +347,7 @@ func (x *GetMappingUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMappingUpdateResponse.ProtoReflect.Descriptor instead. func (*GetMappingUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{1} + return file_proxy_service_proto_rawDescGZIP(), []int{2} } func (x *GetMappingUpdateResponse) GetMapping() []*ProxyMapping { @@ -316,12 +373,16 @@ type PathTargetOptions struct { RequestTimeout *durationpb.Duration `protobuf:"bytes,2,opt,name=request_timeout,json=requestTimeout,proto3" json:"request_timeout,omitempty"` PathRewrite PathRewriteMode `protobuf:"varint,3,opt,name=path_rewrite,json=pathRewrite,proto3,enum=management.PathRewriteMode" json:"path_rewrite,omitempty"` CustomHeaders map[string]string `protobuf:"bytes,4,rep,name=custom_headers,json=customHeaders,proto3" json:"custom_headers,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + // Send PROXY protocol v2 header to this backend. + ProxyProtocol bool `protobuf:"varint,5,opt,name=proxy_protocol,json=proxyProtocol,proto3" json:"proxy_protocol,omitempty"` + // Idle timeout before a UDP session is reaped. + SessionIdleTimeout *durationpb.Duration `protobuf:"bytes,6,opt,name=session_idle_timeout,json=sessionIdleTimeout,proto3" json:"session_idle_timeout,omitempty"` } func (x *PathTargetOptions) Reset() { *x = PathTargetOptions{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[2] + mi := &file_proxy_service_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -334,7 +395,7 @@ func (x *PathTargetOptions) String() string { func (*PathTargetOptions) ProtoMessage() {} func (x *PathTargetOptions) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[2] + mi := &file_proxy_service_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -347,7 +408,7 @@ func (x *PathTargetOptions) ProtoReflect() protoreflect.Message { // Deprecated: Use PathTargetOptions.ProtoReflect.Descriptor instead. func (*PathTargetOptions) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{2} + return file_proxy_service_proto_rawDescGZIP(), []int{3} } func (x *PathTargetOptions) GetSkipTlsVerify() bool { @@ -378,6 +439,20 @@ func (x *PathTargetOptions) GetCustomHeaders() map[string]string { return nil } +func (x *PathTargetOptions) GetProxyProtocol() bool { + if x != nil { + return x.ProxyProtocol + } + return false +} + +func (x *PathTargetOptions) GetSessionIdleTimeout() *durationpb.Duration { + if x != nil { + return x.SessionIdleTimeout + } + return nil +} + type PathMapping struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -391,7 +466,7 @@ type PathMapping struct { func (x *PathMapping) Reset() { *x = PathMapping{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[3] + mi := &file_proxy_service_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -404,7 +479,7 @@ func (x *PathMapping) String() string { func (*PathMapping) ProtoMessage() {} func (x *PathMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[3] + mi := &file_proxy_service_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -417,7 +492,7 @@ func (x *PathMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use PathMapping.ProtoReflect.Descriptor instead. func (*PathMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{3} + return file_proxy_service_proto_rawDescGZIP(), []int{4} } func (x *PathMapping) GetPath() string { @@ -456,7 +531,7 @@ type Authentication struct { func (x *Authentication) Reset() { *x = Authentication{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[4] + mi := &file_proxy_service_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -469,7 +544,7 @@ func (x *Authentication) String() string { func (*Authentication) ProtoMessage() {} func (x *Authentication) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[4] + mi := &file_proxy_service_proto_msgTypes[5] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -482,7 +557,7 @@ func (x *Authentication) ProtoReflect() protoreflect.Message { // Deprecated: Use Authentication.ProtoReflect.Descriptor instead. func (*Authentication) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{4} + return file_proxy_service_proto_rawDescGZIP(), []int{5} } func (x *Authentication) GetSessionKey() string { @@ -538,12 +613,16 @@ type ProxyMapping struct { // When true, Location headers in backend responses are rewritten to replace // the backend address with the public-facing domain. RewriteRedirects bool `protobuf:"varint,9,opt,name=rewrite_redirects,json=rewriteRedirects,proto3" json:"rewrite_redirects,omitempty"` + // Service mode: "http", "tcp", "udp", or "tls". + Mode string `protobuf:"bytes,10,opt,name=mode,proto3" json:"mode,omitempty"` + // For L4/TLS: the port the proxy listens on. + ListenPort int32 `protobuf:"varint,11,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` } func (x *ProxyMapping) Reset() { *x = ProxyMapping{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[5] + mi := &file_proxy_service_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -556,7 +635,7 @@ func (x *ProxyMapping) String() string { func (*ProxyMapping) ProtoMessage() {} func (x *ProxyMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[5] + mi := &file_proxy_service_proto_msgTypes[6] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -569,7 +648,7 @@ func (x *ProxyMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use ProxyMapping.ProtoReflect.Descriptor instead. func (*ProxyMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{5} + return file_proxy_service_proto_rawDescGZIP(), []int{6} } func (x *ProxyMapping) GetType() ProxyMappingUpdateType { @@ -635,6 +714,20 @@ func (x *ProxyMapping) GetRewriteRedirects() bool { return false } +func (x *ProxyMapping) GetMode() string { + if x != nil { + return x.Mode + } + return "" +} + +func (x *ProxyMapping) GetListenPort() int32 { + if x != nil { + return x.ListenPort + } + return 0 +} + // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. type SendAccessLogRequest struct { state protoimpl.MessageState @@ -647,7 +740,7 @@ type SendAccessLogRequest struct { func (x *SendAccessLogRequest) Reset() { *x = SendAccessLogRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[6] + mi := &file_proxy_service_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -660,7 +753,7 @@ func (x *SendAccessLogRequest) String() string { func (*SendAccessLogRequest) ProtoMessage() {} func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[6] + mi := &file_proxy_service_proto_msgTypes[7] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -673,7 +766,7 @@ func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogRequest.ProtoReflect.Descriptor instead. func (*SendAccessLogRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{6} + return file_proxy_service_proto_rawDescGZIP(), []int{7} } func (x *SendAccessLogRequest) GetLog() *AccessLog { @@ -693,7 +786,7 @@ type SendAccessLogResponse struct { func (x *SendAccessLogResponse) Reset() { *x = SendAccessLogResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[7] + mi := &file_proxy_service_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -706,7 +799,7 @@ func (x *SendAccessLogResponse) String() string { func (*SendAccessLogResponse) ProtoMessage() {} func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[7] + mi := &file_proxy_service_proto_msgTypes[8] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -719,7 +812,7 @@ func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogResponse.ProtoReflect.Descriptor instead. func (*SendAccessLogResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{7} + return file_proxy_service_proto_rawDescGZIP(), []int{8} } type AccessLog struct { @@ -742,12 +835,13 @@ type AccessLog struct { AuthSuccess bool `protobuf:"varint,13,opt,name=auth_success,json=authSuccess,proto3" json:"auth_success,omitempty"` BytesUpload int64 `protobuf:"varint,14,opt,name=bytes_upload,json=bytesUpload,proto3" json:"bytes_upload,omitempty"` BytesDownload int64 `protobuf:"varint,15,opt,name=bytes_download,json=bytesDownload,proto3" json:"bytes_download,omitempty"` + Protocol string `protobuf:"bytes,16,opt,name=protocol,proto3" json:"protocol,omitempty"` } func (x *AccessLog) Reset() { *x = AccessLog{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[8] + mi := &file_proxy_service_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -760,7 +854,7 @@ func (x *AccessLog) String() string { func (*AccessLog) ProtoMessage() {} func (x *AccessLog) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[8] + mi := &file_proxy_service_proto_msgTypes[9] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -773,7 +867,7 @@ func (x *AccessLog) ProtoReflect() protoreflect.Message { // Deprecated: Use AccessLog.ProtoReflect.Descriptor instead. func (*AccessLog) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{8} + return file_proxy_service_proto_rawDescGZIP(), []int{9} } func (x *AccessLog) GetTimestamp() *timestamppb.Timestamp { @@ -881,6 +975,13 @@ func (x *AccessLog) GetBytesDownload() int64 { return 0 } +func (x *AccessLog) GetProtocol() string { + if x != nil { + return x.Protocol + } + return "" +} + type AuthenticateRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -898,7 +999,7 @@ type AuthenticateRequest struct { func (x *AuthenticateRequest) Reset() { *x = AuthenticateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[9] + mi := &file_proxy_service_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -911,7 +1012,7 @@ func (x *AuthenticateRequest) String() string { func (*AuthenticateRequest) ProtoMessage() {} func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[9] + mi := &file_proxy_service_proto_msgTypes[10] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -924,7 +1025,7 @@ func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateRequest.ProtoReflect.Descriptor instead. func (*AuthenticateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{9} + return file_proxy_service_proto_rawDescGZIP(), []int{10} } func (x *AuthenticateRequest) GetId() string { @@ -989,7 +1090,7 @@ type PasswordRequest struct { func (x *PasswordRequest) Reset() { *x = PasswordRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[10] + mi := &file_proxy_service_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1002,7 +1103,7 @@ func (x *PasswordRequest) String() string { func (*PasswordRequest) ProtoMessage() {} func (x *PasswordRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[10] + mi := &file_proxy_service_proto_msgTypes[11] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1015,7 +1116,7 @@ func (x *PasswordRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PasswordRequest.ProtoReflect.Descriptor instead. func (*PasswordRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{10} + return file_proxy_service_proto_rawDescGZIP(), []int{11} } func (x *PasswordRequest) GetPassword() string { @@ -1036,7 +1137,7 @@ type PinRequest struct { func (x *PinRequest) Reset() { *x = PinRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[11] + mi := &file_proxy_service_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1049,7 +1150,7 @@ func (x *PinRequest) String() string { func (*PinRequest) ProtoMessage() {} func (x *PinRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[11] + mi := &file_proxy_service_proto_msgTypes[12] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1062,7 +1163,7 @@ func (x *PinRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PinRequest.ProtoReflect.Descriptor instead. func (*PinRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{11} + return file_proxy_service_proto_rawDescGZIP(), []int{12} } func (x *PinRequest) GetPin() string { @@ -1084,7 +1185,7 @@ type AuthenticateResponse struct { func (x *AuthenticateResponse) Reset() { *x = AuthenticateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[12] + mi := &file_proxy_service_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1097,7 +1198,7 @@ func (x *AuthenticateResponse) String() string { func (*AuthenticateResponse) ProtoMessage() {} func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[12] + mi := &file_proxy_service_proto_msgTypes[13] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1110,7 +1211,7 @@ func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateResponse.ProtoReflect.Descriptor instead. func (*AuthenticateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{12} + return file_proxy_service_proto_rawDescGZIP(), []int{13} } func (x *AuthenticateResponse) GetSuccess() bool { @@ -1143,7 +1244,7 @@ type SendStatusUpdateRequest struct { func (x *SendStatusUpdateRequest) Reset() { *x = SendStatusUpdateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[13] + mi := &file_proxy_service_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1156,7 +1257,7 @@ func (x *SendStatusUpdateRequest) String() string { func (*SendStatusUpdateRequest) ProtoMessage() {} func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[13] + mi := &file_proxy_service_proto_msgTypes[14] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1169,7 +1270,7 @@ func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateRequest.ProtoReflect.Descriptor instead. func (*SendStatusUpdateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{13} + return file_proxy_service_proto_rawDescGZIP(), []int{14} } func (x *SendStatusUpdateRequest) GetServiceId() string { @@ -1217,7 +1318,7 @@ type SendStatusUpdateResponse struct { func (x *SendStatusUpdateResponse) Reset() { *x = SendStatusUpdateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[14] + mi := &file_proxy_service_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1230,7 +1331,7 @@ func (x *SendStatusUpdateResponse) String() string { func (*SendStatusUpdateResponse) ProtoMessage() {} func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[14] + mi := &file_proxy_service_proto_msgTypes[15] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1243,7 +1344,7 @@ func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateResponse.ProtoReflect.Descriptor instead. func (*SendStatusUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{14} + return file_proxy_service_proto_rawDescGZIP(), []int{15} } // CreateProxyPeerRequest is sent by the proxy to create a peer connection @@ -1263,7 +1364,7 @@ type CreateProxyPeerRequest struct { func (x *CreateProxyPeerRequest) Reset() { *x = CreateProxyPeerRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[15] + mi := &file_proxy_service_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1276,7 +1377,7 @@ func (x *CreateProxyPeerRequest) String() string { func (*CreateProxyPeerRequest) ProtoMessage() {} func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[15] + mi := &file_proxy_service_proto_msgTypes[16] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1289,7 +1390,7 @@ func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerRequest.ProtoReflect.Descriptor instead. func (*CreateProxyPeerRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{15} + return file_proxy_service_proto_rawDescGZIP(), []int{16} } func (x *CreateProxyPeerRequest) GetServiceId() string { @@ -1340,7 +1441,7 @@ type CreateProxyPeerResponse struct { func (x *CreateProxyPeerResponse) Reset() { *x = CreateProxyPeerResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[16] + mi := &file_proxy_service_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1353,7 +1454,7 @@ func (x *CreateProxyPeerResponse) String() string { func (*CreateProxyPeerResponse) ProtoMessage() {} func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[16] + mi := &file_proxy_service_proto_msgTypes[17] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1366,7 +1467,7 @@ func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerResponse.ProtoReflect.Descriptor instead. func (*CreateProxyPeerResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{16} + return file_proxy_service_proto_rawDescGZIP(), []int{17} } func (x *CreateProxyPeerResponse) GetSuccess() bool { @@ -1396,7 +1497,7 @@ type GetOIDCURLRequest struct { func (x *GetOIDCURLRequest) Reset() { *x = GetOIDCURLRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[17] + mi := &file_proxy_service_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1409,7 +1510,7 @@ func (x *GetOIDCURLRequest) String() string { func (*GetOIDCURLRequest) ProtoMessage() {} func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[17] + mi := &file_proxy_service_proto_msgTypes[18] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1422,7 +1523,7 @@ func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLRequest.ProtoReflect.Descriptor instead. func (*GetOIDCURLRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{17} + return file_proxy_service_proto_rawDescGZIP(), []int{18} } func (x *GetOIDCURLRequest) GetId() string { @@ -1457,7 +1558,7 @@ type GetOIDCURLResponse struct { func (x *GetOIDCURLResponse) Reset() { *x = GetOIDCURLResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1470,7 +1571,7 @@ func (x *GetOIDCURLResponse) String() string { func (*GetOIDCURLResponse) ProtoMessage() {} func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1483,7 +1584,7 @@ func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLResponse.ProtoReflect.Descriptor instead. func (*GetOIDCURLResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{18} + return file_proxy_service_proto_rawDescGZIP(), []int{19} } func (x *GetOIDCURLResponse) GetUrl() string { @@ -1505,7 +1606,7 @@ type ValidateSessionRequest struct { func (x *ValidateSessionRequest) Reset() { *x = ValidateSessionRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1518,7 +1619,7 @@ func (x *ValidateSessionRequest) String() string { func (*ValidateSessionRequest) ProtoMessage() {} func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1531,7 +1632,7 @@ func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionRequest.ProtoReflect.Descriptor instead. func (*ValidateSessionRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{19} + return file_proxy_service_proto_rawDescGZIP(), []int{20} } func (x *ValidateSessionRequest) GetDomain() string { @@ -1562,7 +1663,7 @@ type ValidateSessionResponse struct { func (x *ValidateSessionResponse) Reset() { *x = ValidateSessionResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[20] + mi := &file_proxy_service_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1575,7 +1676,7 @@ func (x *ValidateSessionResponse) String() string { func (*ValidateSessionResponse) ProtoMessage() {} func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[20] + mi := &file_proxy_service_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1588,7 +1689,7 @@ func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionResponse.ProtoReflect.Descriptor instead. func (*ValidateSessionResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{20} + return file_proxy_service_proto_rawDescGZIP(), []int{21} } func (x *ValidateSessionResponse) GetValid() bool { @@ -1628,124 +1729,147 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x22, 0xa3, 0x01, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, - 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, - 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, - 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, - 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, - 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, - 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x52, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, - 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, - 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, - 0x6c, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0xda, 0x02, - 0x0a, 0x11, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, - 0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, - 0x69, 0x70, 0x54, 0x6c, 0x73, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, 0x0a, 0x0f, 0x72, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x0e, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, - 0x3e, 0x0a, 0x0c, 0x70, 0x61, 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, - 0x64, 0x65, 0x52, 0x0b, 0x70, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x12, - 0x57, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, - 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, - 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, - 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, - 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x72, 0x0a, 0x0b, 0x50, 0x61, - 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, - 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, - 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, - 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0xaa, - 0x01, 0x0a, 0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, - 0x65, 0x79, 0x12, 0x35, 0x0a, 0x17, 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, - 0x6e, 0x5f, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x14, 0x6d, 0x61, 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, - 0x67, 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, - 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, - 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x22, 0xe0, 0x02, 0x0a, 0x0c, - 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x36, 0x0a, 0x04, - 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, - 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, - 0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, - 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, - 0x74, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2b, 0x0a, 0x04, 0x70, - 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, - 0x6e, 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, - 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x75, - 0x74, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, 0x75, 0x74, 0x68, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x61, 0x73, 0x73, 0x5f, - 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x08, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, 0x65, 0x61, 0x64, 0x65, - 0x72, 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x5f, 0x72, 0x65, 0x64, - 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x65, - 0x77, 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x22, 0x3f, - 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, 0x67, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, 0x6c, 0x6f, 0x67, 0x22, - 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xea, 0x03, 0x0a, 0x09, 0x41, 0x63, 0x63, - 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, - 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, - 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, - 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, - 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x0a, - 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x73, 0x12, 0x16, - 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, - 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x72, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, 0x61, 0x75, 0x74, 0x68, - 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x12, - 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x75, 0x74, 0x68, - 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, - 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x62, - 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x25, - 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, - 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, 0x73, 0x44, 0x6f, 0x77, - 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0xb6, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, + 0x74, 0x6f, 0x22, 0x66, 0x0a, 0x11, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, + 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x12, 0x37, 0x0a, 0x15, 0x73, 0x75, 0x70, 0x70, 0x6f, + 0x72, 0x74, 0x73, 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x13, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, + 0x74, 0x73, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x6f, 0x72, 0x74, 0x73, 0x88, 0x01, 0x01, + 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x5f, 0x63, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x22, 0xe6, 0x01, 0x0a, 0x17, 0x47, + 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, + 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, + 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, + 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, + 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, + 0x69, 0x65, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, + 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, 0x70, + 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f, + 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, 0x63, + 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0xce, 0x03, 0x0a, 0x11, 0x50, 0x61, 0x74, + 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x26, + 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x54, 0x6c, 0x73, + 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, 0x0a, 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0e, 0x72, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x3e, 0x0a, 0x0c, 0x70, 0x61, + 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, + 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x0b, 0x70, + 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x12, 0x57, 0x0a, 0x0e, 0x63, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, + 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, + 0x65, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x70, 0x72, 0x6f, + 0x78, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x4b, 0x0a, 0x14, 0x73, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x6c, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, + 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x12, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x6c, 0x65, + 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, + 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, + 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x72, 0x0a, 0x0b, 0x50, 0x61, 0x74, + 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, + 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x61, + 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0xaa, 0x01, + 0x0a, 0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, + 0x79, 0x12, 0x35, 0x0a, 0x17, 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, + 0x5f, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x14, 0x6d, 0x61, 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, 0x67, + 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, + 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, + 0x77, 0x6f, 0x72, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x22, 0x95, 0x03, 0x0a, 0x0c, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x36, 0x0a, 0x04, 0x74, + 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, + 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, + 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2b, 0x0a, 0x04, 0x70, 0x61, + 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, + 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x75, 0x74, + 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, 0x75, 0x74, 0x68, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x61, 0x73, 0x73, 0x5f, 0x68, + 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, + 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x5f, 0x72, 0x65, 0x64, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x65, 0x77, + 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x12, 0x12, 0x0a, + 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6d, 0x6f, 0x64, + 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, + 0x72, 0x74, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, + 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, + 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x86, 0x04, 0x0a, + 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, + 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, + 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, + 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, + 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x4d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, + 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, + 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, + 0x69, 0x73, 0x6d, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, + 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, + 0x21, 0x0a, 0x0c, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, + 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, + 0x61, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, + 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, + 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0xb6, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, @@ -1907,70 +2031,73 @@ func file_proxy_service_proto_rawDescGZIP() []byte { } var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 22) +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 23) var file_proxy_service_proto_goTypes = []interface{}{ (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType (PathRewriteMode)(0), // 1: management.PathRewriteMode (ProxyStatus)(0), // 2: management.ProxyStatus - (*GetMappingUpdateRequest)(nil), // 3: management.GetMappingUpdateRequest - (*GetMappingUpdateResponse)(nil), // 4: management.GetMappingUpdateResponse - (*PathTargetOptions)(nil), // 5: management.PathTargetOptions - (*PathMapping)(nil), // 6: management.PathMapping - (*Authentication)(nil), // 7: management.Authentication - (*ProxyMapping)(nil), // 8: management.ProxyMapping - (*SendAccessLogRequest)(nil), // 9: management.SendAccessLogRequest - (*SendAccessLogResponse)(nil), // 10: management.SendAccessLogResponse - (*AccessLog)(nil), // 11: management.AccessLog - (*AuthenticateRequest)(nil), // 12: management.AuthenticateRequest - (*PasswordRequest)(nil), // 13: management.PasswordRequest - (*PinRequest)(nil), // 14: management.PinRequest - (*AuthenticateResponse)(nil), // 15: management.AuthenticateResponse - (*SendStatusUpdateRequest)(nil), // 16: management.SendStatusUpdateRequest - (*SendStatusUpdateResponse)(nil), // 17: management.SendStatusUpdateResponse - (*CreateProxyPeerRequest)(nil), // 18: management.CreateProxyPeerRequest - (*CreateProxyPeerResponse)(nil), // 19: management.CreateProxyPeerResponse - (*GetOIDCURLRequest)(nil), // 20: management.GetOIDCURLRequest - (*GetOIDCURLResponse)(nil), // 21: management.GetOIDCURLResponse - (*ValidateSessionRequest)(nil), // 22: management.ValidateSessionRequest - (*ValidateSessionResponse)(nil), // 23: management.ValidateSessionResponse - nil, // 24: management.PathTargetOptions.CustomHeadersEntry - (*timestamppb.Timestamp)(nil), // 25: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 26: google.protobuf.Duration + (*ProxyCapabilities)(nil), // 3: management.ProxyCapabilities + (*GetMappingUpdateRequest)(nil), // 4: management.GetMappingUpdateRequest + (*GetMappingUpdateResponse)(nil), // 5: management.GetMappingUpdateResponse + (*PathTargetOptions)(nil), // 6: management.PathTargetOptions + (*PathMapping)(nil), // 7: management.PathMapping + (*Authentication)(nil), // 8: management.Authentication + (*ProxyMapping)(nil), // 9: management.ProxyMapping + (*SendAccessLogRequest)(nil), // 10: management.SendAccessLogRequest + (*SendAccessLogResponse)(nil), // 11: management.SendAccessLogResponse + (*AccessLog)(nil), // 12: management.AccessLog + (*AuthenticateRequest)(nil), // 13: management.AuthenticateRequest + (*PasswordRequest)(nil), // 14: management.PasswordRequest + (*PinRequest)(nil), // 15: management.PinRequest + (*AuthenticateResponse)(nil), // 16: management.AuthenticateResponse + (*SendStatusUpdateRequest)(nil), // 17: management.SendStatusUpdateRequest + (*SendStatusUpdateResponse)(nil), // 18: management.SendStatusUpdateResponse + (*CreateProxyPeerRequest)(nil), // 19: management.CreateProxyPeerRequest + (*CreateProxyPeerResponse)(nil), // 20: management.CreateProxyPeerResponse + (*GetOIDCURLRequest)(nil), // 21: management.GetOIDCURLRequest + (*GetOIDCURLResponse)(nil), // 22: management.GetOIDCURLResponse + (*ValidateSessionRequest)(nil), // 23: management.ValidateSessionRequest + (*ValidateSessionResponse)(nil), // 24: management.ValidateSessionResponse + nil, // 25: management.PathTargetOptions.CustomHeadersEntry + (*timestamppb.Timestamp)(nil), // 26: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 27: google.protobuf.Duration } var file_proxy_service_proto_depIdxs = []int32{ - 25, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp - 8, // 1: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping - 26, // 2: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration - 1, // 3: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode - 24, // 4: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry - 5, // 5: management.PathMapping.options:type_name -> management.PathTargetOptions - 0, // 6: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType - 6, // 7: management.ProxyMapping.path:type_name -> management.PathMapping - 7, // 8: management.ProxyMapping.auth:type_name -> management.Authentication - 11, // 9: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 25, // 10: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp - 13, // 11: management.AuthenticateRequest.password:type_name -> management.PasswordRequest - 14, // 12: management.AuthenticateRequest.pin:type_name -> management.PinRequest - 2, // 13: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus - 3, // 14: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest - 9, // 15: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest - 12, // 16: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest - 16, // 17: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest - 18, // 18: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest - 20, // 19: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest - 22, // 20: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest - 4, // 21: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 10, // 22: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 15, // 23: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 17, // 24: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 19, // 25: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 21, // 26: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 23, // 27: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse - 21, // [21:28] is the sub-list for method output_type - 14, // [14:21] is the sub-list for method input_type - 14, // [14:14] is the sub-list for extension type_name - 14, // [14:14] is the sub-list for extension extendee - 0, // [0:14] is the sub-list for field type_name + 26, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 3, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities + 9, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping + 27, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration + 1, // 4: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode + 25, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry + 27, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration + 6, // 7: management.PathMapping.options:type_name -> management.PathTargetOptions + 0, // 8: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType + 7, // 9: management.ProxyMapping.path:type_name -> management.PathMapping + 8, // 10: management.ProxyMapping.auth:type_name -> management.Authentication + 12, // 11: management.SendAccessLogRequest.log:type_name -> management.AccessLog + 26, // 12: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 14, // 13: management.AuthenticateRequest.password:type_name -> management.PasswordRequest + 15, // 14: management.AuthenticateRequest.pin:type_name -> management.PinRequest + 2, // 15: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus + 4, // 16: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest + 10, // 17: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest + 13, // 18: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest + 17, // 19: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest + 19, // 20: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest + 21, // 21: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest + 23, // 22: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 5, // 23: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 11, // 24: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 16, // 25: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 18, // 26: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 20, // 27: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 22, // 28: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 24, // 29: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 23, // [23:30] is the sub-list for method output_type + 16, // [16:23] is the sub-list for method input_type + 16, // [16:16] is the sub-list for extension type_name + 16, // [16:16] is the sub-list for extension extendee + 0, // [0:16] is the sub-list for field type_name } func init() { file_proxy_service_proto_init() } @@ -1980,7 +2107,7 @@ func file_proxy_service_proto_init() { } if !protoimpl.UnsafeEnabled { file_proxy_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetMappingUpdateRequest); i { + switch v := v.(*ProxyCapabilities); i { case 0: return &v.state case 1: @@ -1992,7 +2119,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetMappingUpdateResponse); i { + switch v := v.(*GetMappingUpdateRequest); i { case 0: return &v.state case 1: @@ -2004,7 +2131,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PathTargetOptions); i { + switch v := v.(*GetMappingUpdateResponse); i { case 0: return &v.state case 1: @@ -2016,7 +2143,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PathMapping); i { + switch v := v.(*PathTargetOptions); i { case 0: return &v.state case 1: @@ -2028,7 +2155,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Authentication); i { + switch v := v.(*PathMapping); i { case 0: return &v.state case 1: @@ -2040,7 +2167,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProxyMapping); i { + switch v := v.(*Authentication); i { case 0: return &v.state case 1: @@ -2052,7 +2179,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogRequest); i { + switch v := v.(*ProxyMapping); i { case 0: return &v.state case 1: @@ -2064,7 +2191,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogResponse); i { + switch v := v.(*SendAccessLogRequest); i { case 0: return &v.state case 1: @@ -2076,7 +2203,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AccessLog); i { + switch v := v.(*SendAccessLogResponse); i { case 0: return &v.state case 1: @@ -2088,7 +2215,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateRequest); i { + switch v := v.(*AccessLog); i { case 0: return &v.state case 1: @@ -2100,7 +2227,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PasswordRequest); i { + switch v := v.(*AuthenticateRequest); i { case 0: return &v.state case 1: @@ -2112,7 +2239,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PinRequest); i { + switch v := v.(*PasswordRequest); i { case 0: return &v.state case 1: @@ -2124,7 +2251,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateResponse); i { + switch v := v.(*PinRequest); i { case 0: return &v.state case 1: @@ -2136,7 +2263,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateRequest); i { + switch v := v.(*AuthenticateResponse); i { case 0: return &v.state case 1: @@ -2148,7 +2275,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateResponse); i { + switch v := v.(*SendStatusUpdateRequest); i { case 0: return &v.state case 1: @@ -2160,7 +2287,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerRequest); i { + switch v := v.(*SendStatusUpdateResponse); i { case 0: return &v.state case 1: @@ -2172,7 +2299,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerResponse); i { + switch v := v.(*CreateProxyPeerRequest); i { case 0: return &v.state case 1: @@ -2184,7 +2311,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLRequest); i { + switch v := v.(*CreateProxyPeerResponse); i { case 0: return &v.state case 1: @@ -2196,7 +2323,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLResponse); i { + switch v := v.(*GetOIDCURLRequest); i { case 0: return &v.state case 1: @@ -2208,7 +2335,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateSessionRequest); i { + switch v := v.(*GetOIDCURLResponse); i { case 0: return &v.state case 1: @@ -2220,6 +2347,18 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ValidateSessionRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ValidateSessionResponse); i { case 0: return &v.state @@ -2232,19 +2371,20 @@ func file_proxy_service_proto_init() { } } } - file_proxy_service_proto_msgTypes[9].OneofWrappers = []interface{}{ + file_proxy_service_proto_msgTypes[0].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[10].OneofWrappers = []interface{}{ (*AuthenticateRequest_Password)(nil), (*AuthenticateRequest_Pin)(nil), } - file_proxy_service_proto_msgTypes[13].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[16].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[14].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[17].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proxy_service_proto_rawDesc, NumEnums: 3, - NumMessages: 22, + NumMessages: 23, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index 195b60f01..457d12e85 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -27,12 +27,19 @@ service ProxyService { rpc ValidateSession(ValidateSessionRequest) returns (ValidateSessionResponse); } +// ProxyCapabilities describes what a proxy can handle. +message ProxyCapabilities { + // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. + optional bool supports_custom_ports = 1; +} + // GetMappingUpdateRequest is sent to initialise a mapping stream. message GetMappingUpdateRequest { string proxy_id = 1; string version = 2; google.protobuf.Timestamp started_at = 3; string address = 4; + ProxyCapabilities capabilities = 5; } // GetMappingUpdateResponse contains zero or more ProxyMappings. @@ -61,6 +68,10 @@ message PathTargetOptions { google.protobuf.Duration request_timeout = 2; PathRewriteMode path_rewrite = 3; map custom_headers = 4; + // Send PROXY protocol v2 header to this backend. + bool proxy_protocol = 5; + // Idle timeout before a UDP session is reaped. + google.protobuf.Duration session_idle_timeout = 6; } message PathMapping { @@ -91,6 +102,10 @@ message ProxyMapping { // When true, Location headers in backend responses are rewritten to replace // the backend address with the public-facing domain. bool rewrite_redirects = 9; + // Service mode: "http", "tcp", "udp", or "tls". + string mode = 10; + // For L4/TLS: the port the proxy listens on. + int32 listen_port = 11; } // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. @@ -117,6 +132,7 @@ message AccessLog { bool auth_success = 13; int64 bytes_upload = 14; int64 bytes_download = 15; + string protocol = 16; } message AuthenticateRequest {
Account IDDomainsServices Age Status
{{.AccountID}}{{.Domains}}{{.Services}} {{.Age}} {{.Status}}