diff --git a/client/cmd/expose.go b/client/cmd/expose.go index 991d3ab86..1334617d8 100644 --- a/client/cmd/expose.go +++ b/client/cmd/expose.go @@ -22,20 +22,24 @@ import ( var pinRegexp = regexp.MustCompile(`^\d{6}$`) var ( - exposePin string - exposePassword string - exposeUserGroups []string - exposeDomain string - exposeNamePrefix string - exposeProtocol string + exposePin string + exposePassword string + exposeUserGroups []string + exposeDomain string + exposeNamePrefix string + exposeProtocol string + exposeExternalPort uint16 ) var exposeCmd = &cobra.Command{ - Use: "expose ", - Short: "Expose a local port via the NetBird reverse proxy", - Args: cobra.ExactArgs(1), - Example: "netbird expose --with-password safe-pass 8080", - RunE: exposeFn, + Use: "expose ", + Short: "Expose a local port via the NetBird reverse proxy", + Args: cobra.ExactArgs(1), + Example: ` netbird expose --with-password safe-pass 8080 + netbird expose --protocol tcp 5432 + netbird expose --protocol tcp --with-external-port 5433 5432 + netbird expose --protocol tls --with-custom-domain tls.example.com 4443`, + RunE: exposeFn, } func init() { @@ -44,7 +48,52 @@ func init() { exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)") exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)") exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)") - exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)") + exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)") + exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)") +} + +// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags. +func isClusterProtocol(protocol string) bool { + switch strings.ToLower(protocol) { + case "tcp", "udp", "tls": + return true + default: + return false + } +} + +// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP) +// where domain display doesn't apply. TLS uses SNI so it has a domain. +func isPortBasedProtocol(protocol string) bool { + switch strings.ToLower(protocol) { + case "tcp", "udp": + return true + default: + return false + } +} + +// extractPort returns the port portion of a URL like "tcp://host:12345", or +// falls back to the given default formatted as a string. +func extractPort(serviceURL string, fallback uint16) string { + u := serviceURL + if idx := strings.Index(u, "://"); idx != -1 { + u = u[idx+3:] + } + if i := strings.LastIndex(u, ":"); i != -1 { + if p := u[i+1:]; p != "" { + return p + } + } + return strconv.FormatUint(uint64(fallback), 10) +} + +// resolveExternalPort returns the effective external port, defaulting to the target port. +func resolveExternalPort(targetPort uint64) uint16 { + if exposeExternalPort != 0 { + return exposeExternalPort + } + return uint16(targetPort) } func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) { @@ -57,7 +106,15 @@ func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) { } if !isProtocolValid(exposeProtocol) { - return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol) + return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol) + } + + if isClusterProtocol(exposeProtocol) { + if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 { + return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol) + } + } else if cmd.Flags().Changed("with-external-port") { + return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol) } if exposePin != "" && !pinRegexp.MatchString(exposePin) { @@ -76,7 +133,12 @@ func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) { } func isProtocolValid(exposeProtocol string) bool { - return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https" + switch strings.ToLower(exposeProtocol) { + case "http", "https", "tcp", "udp", "tls": + return true + default: + return false + } } func exposeFn(cmd *cobra.Command, args []string) error { @@ -123,7 +185,7 @@ func exposeFn(cmd *cobra.Command, args []string) error { return err } - stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{ + req := &proto.ExposeServiceRequest{ Port: uint32(port), Protocol: protocol, Pin: exposePin, @@ -131,7 +193,12 @@ func exposeFn(cmd *cobra.Command, args []string) error { UserGroups: exposeUserGroups, Domain: exposeDomain, NamePrefix: exposeNamePrefix, - }) + } + if isClusterProtocol(exposeProtocol) { + req.ListenPort = uint32(resolveExternalPort(port)) + } + + stream, err := client.ExposeService(ctx, req) if err != nil { return fmt.Errorf("expose service: %w", err) } @@ -149,8 +216,14 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) { return proto.ExposeProtocol_EXPOSE_HTTP, nil case "https": return proto.ExposeProtocol_EXPOSE_HTTPS, nil + case "tcp": + return proto.ExposeProtocol_EXPOSE_TCP, nil + case "udp": + return proto.ExposeProtocol_EXPOSE_UDP, nil + case "tls": + return proto.ExposeProtocol_EXPOSE_TLS, nil default: - return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol) + return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol) } } @@ -160,20 +233,33 @@ func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServ return fmt.Errorf("receive expose event: %w", err) } - switch e := event.Event.(type) { - case *proto.ExposeServiceEvent_Ready: - cmd.Println("Service exposed successfully!") - cmd.Printf(" Name: %s\n", e.Ready.ServiceName) - cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl) - cmd.Printf(" Domain: %s\n", e.Ready.Domain) - cmd.Printf(" Protocol: %s\n", exposeProtocol) - cmd.Printf(" Port: %d\n", port) - cmd.Println() - cmd.Println("Press Ctrl+C to stop exposing.") - return nil - default: + ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready) + if !ok { return fmt.Errorf("unexpected expose event: %T", event.Event) } + printExposeReady(cmd, ready.Ready, port) + return nil +} + +func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) { + cmd.Println("Service exposed successfully!") + cmd.Printf(" Name: %s\n", r.ServiceName) + if r.ServiceUrl != "" { + cmd.Printf(" URL: %s\n", r.ServiceUrl) + } + if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) { + cmd.Printf(" Domain: %s\n", r.Domain) + } + cmd.Printf(" Protocol: %s\n", exposeProtocol) + cmd.Printf(" Internal: %d\n", port) + if isClusterProtocol(exposeProtocol) { + cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port))) + } + if r.PortAutoAssigned && exposeExternalPort != 0 { + cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort) + } + cmd.Println() + cmd.Println("Press Ctrl+C to stop exposing.") } func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error { diff --git a/client/internal/expose/manager.go b/client/internal/expose/manager.go index 8cd93685e..c59a1a7bd 100644 --- a/client/internal/expose/manager.go +++ b/client/internal/expose/manager.go @@ -12,9 +12,10 @@ const renewTimeout = 10 * time.Second // Response holds the response from exposing a service. type Response struct { - ServiceName string - ServiceURL string - Domain string + ServiceName string + ServiceURL string + Domain string + PortAutoAssigned bool } type Request struct { @@ -25,6 +26,7 @@ type Request struct { Pin string Password string UserGroups []string + ListenPort uint16 } type ManagementClient interface { diff --git a/client/internal/expose/request.go b/client/internal/expose/request.go index 7e12d0513..bff4f2ce7 100644 --- a/client/internal/expose/request.go +++ b/client/internal/expose/request.go @@ -15,6 +15,7 @@ func NewRequest(req *daemonProto.ExposeServiceRequest) *Request { UserGroups: req.UserGroups, Domain: req.Domain, NamePrefix: req.NamePrefix, + ListenPort: uint16(req.ListenPort), } } @@ -27,13 +28,15 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest { Pin: req.Pin, Password: req.Password, UserGroups: req.UserGroups, + ListenPort: req.ListenPort, } } func fromClientExposeResponse(response *mgm.ExposeResponse) *Response { return &Response{ - ServiceName: response.ServiceName, - Domain: response.Domain, - ServiceURL: response.ServiceURL, + ServiceName: response.ServiceName, + Domain: response.Domain, + ServiceURL: response.ServiceURL, + PortAutoAssigned: response.PortAutoAssigned, } } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index fd3c18f56..fa0b2f93b 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -95,6 +95,7 @@ const ( ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1 ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2 ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3 + ExposeProtocol_EXPOSE_TLS ExposeProtocol = 4 ) // Enum value maps for ExposeProtocol. @@ -104,12 +105,14 @@ var ( 1: "EXPOSE_HTTPS", 2: "EXPOSE_TCP", 3: "EXPOSE_UDP", + 4: "EXPOSE_TLS", } ExposeProtocol_value = map[string]int32{ "EXPOSE_HTTP": 0, "EXPOSE_HTTPS": 1, "EXPOSE_TCP": 2, "EXPOSE_UDP": 3, + "EXPOSE_TLS": 4, } ) @@ -5741,6 +5744,7 @@ type ExposeServiceRequest struct { UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"` Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"` NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"` + ListenPort uint32 `protobuf:"varint,8,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -5824,6 +5828,13 @@ func (x *ExposeServiceRequest) GetNamePrefix() string { return "" } +func (x *ExposeServiceRequest) GetListenPort() uint32 { + if x != nil { + return x.ListenPort + } + return 0 +} + type ExposeServiceEvent struct { state protoimpl.MessageState `protogen:"open.v1"` // Types that are valid to be assigned to Event: @@ -5891,12 +5902,13 @@ type ExposeServiceEvent_Ready struct { func (*ExposeServiceEvent_Ready) isExposeServiceEvent_Event() {} type ExposeServiceReady struct { - state protoimpl.MessageState `protogen:"open.v1"` - ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` - ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` - Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` + ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` + Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + PortAutoAssigned bool `protobuf:"varint,4,opt,name=port_auto_assigned,json=portAutoAssigned,proto3" json:"port_auto_assigned,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ExposeServiceReady) Reset() { @@ -5950,6 +5962,13 @@ func (x *ExposeServiceReady) GetDomain() string { return "" } +func (x *ExposeServiceReady) GetPortAutoAssigned() bool { + if x != nil { + return x.PortAutoAssigned + } + return false +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -6499,7 +6518,7 @@ const file_daemon_proto_rawDesc = "" + "\x16InstallerResultRequest\"O\n" + "\x17InstallerResultResponse\x12\x18\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + - "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\xe6\x01\n" + + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\x87\x02\n" + "\x14ExposeServiceRequest\x12\x12\n" + "\x04port\x18\x01 \x01(\rR\x04port\x122\n" + "\bprotocol\x18\x02 \x01(\x0e2\x16.daemon.ExposeProtocolR\bprotocol\x12\x10\n" + @@ -6509,15 +6528,18 @@ const file_daemon_proto_rawDesc = "" + "userGroups\x12\x16\n" + "\x06domain\x18\x06 \x01(\tR\x06domain\x12\x1f\n" + "\vname_prefix\x18\a \x01(\tR\n" + - "namePrefix\"Q\n" + + "namePrefix\x12\x1f\n" + + "\vlisten_port\x18\b \x01(\rR\n" + + "listenPort\"Q\n" + "\x12ExposeServiceEvent\x122\n" + "\x05ready\x18\x01 \x01(\v2\x1a.daemon.ExposeServiceReadyH\x00R\x05readyB\a\n" + - "\x05event\"p\n" + + "\x05event\"\x9e\x01\n" + "\x12ExposeServiceReady\x12!\n" + "\fservice_name\x18\x01 \x01(\tR\vserviceName\x12\x1f\n" + "\vservice_url\x18\x02 \x01(\tR\n" + "serviceUrl\x12\x16\n" + - "\x06domain\x18\x03 \x01(\tR\x06domain*b\n" + + "\x06domain\x18\x03 \x01(\tR\x06domain\x12,\n" + + "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -6526,14 +6548,16 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a*S\n" + + "\x05TRACE\x10\a*c\n" + "\x0eExposeProtocol\x12\x0f\n" + "\vEXPOSE_HTTP\x10\x00\x12\x10\n" + "\fEXPOSE_HTTPS\x10\x01\x12\x0e\n" + "\n" + "EXPOSE_TCP\x10\x02\x12\x0e\n" + "\n" + - "EXPOSE_UDP\x10\x032\xfc\x15\n" + + "EXPOSE_UDP\x10\x03\x12\x0e\n" + + "\n" + + "EXPOSE_TLS\x10\x042\xfc\x15\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index efafe3af7..89302c8c3 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -821,6 +821,7 @@ enum ExposeProtocol { EXPOSE_HTTPS = 1; EXPOSE_TCP = 2; EXPOSE_UDP = 3; + EXPOSE_TLS = 4; } message ExposeServiceRequest { @@ -831,6 +832,7 @@ message ExposeServiceRequest { repeated string user_groups = 5; string domain = 6; string name_prefix = 7; + uint32 listen_port = 8; } message ExposeServiceEvent { @@ -843,4 +845,5 @@ message ExposeServiceReady { string service_name = 1; string service_url = 2; string domain = 3; + bool port_auto_assigned = 4; } diff --git a/client/server/server.go b/client/server/server.go index 1d83366ca..7c1e70692 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -1378,9 +1378,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon if err := srv.Send(&proto.ExposeServiceEvent{ Event: &proto.ExposeServiceEvent_Ready{ Ready: &proto.ExposeServiceReady{ - ServiceName: result.ServiceName, - ServiceUrl: result.ServiceURL, - Domain: result.Domain, + ServiceName: result.ServiceName, + ServiceUrl: result.ServiceURL, + Domain: result.Domain, + PortAutoAssigned: result.PortAutoAssigned, }, }, }); err != nil { diff --git a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go index 0bcc59b68..619a34684 100644 --- a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go +++ b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go @@ -10,6 +10,15 @@ import ( "github.com/netbirdio/netbird/shared/management/proto" ) +// AccessLogProtocol identifies the transport protocol of an access log entry. +type AccessLogProtocol string + +const ( + AccessLogProtocolHTTP AccessLogProtocol = "http" + AccessLogProtocolTCP AccessLogProtocol = "tcp" + AccessLogProtocolUDP AccessLogProtocol = "udp" +) + type AccessLogEntry struct { ID string `gorm:"primaryKey"` AccountID string `gorm:"index"` @@ -22,10 +31,11 @@ type AccessLogEntry struct { Duration time.Duration `gorm:"index"` StatusCode int `gorm:"index"` Reason string - UserId string `gorm:"index"` - AuthMethodUsed string `gorm:"index"` - BytesUpload int64 `gorm:"index"` - BytesDownload int64 `gorm:"index"` + UserId string `gorm:"index"` + AuthMethodUsed string `gorm:"index"` + BytesUpload int64 `gorm:"index"` + BytesDownload int64 `gorm:"index"` + Protocol AccessLogProtocol `gorm:"index"` } // FromProto creates an AccessLogEntry from a proto.AccessLog @@ -43,17 +53,22 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) { a.AccountID = serviceLog.GetAccountId() a.BytesUpload = serviceLog.GetBytesUpload() a.BytesDownload = serviceLog.GetBytesDownload() + a.Protocol = AccessLogProtocol(serviceLog.GetProtocol()) if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" { - if ip, err := netip.ParseAddr(sourceIP); err == nil { - a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice()) + if addr, err := netip.ParseAddr(sourceIP); err == nil { + addr = addr.Unmap() + a.GeoLocation.ConnectionIP = net.IP(addr.AsSlice()) } } - if !serviceLog.GetAuthSuccess() { - a.Reason = "Authentication failed" - } else if serviceLog.GetResponseCode() >= 400 { - a.Reason = "Request failed" + // Only set reason for HTTP entries. L4 entries have no auth or status code. + if a.Protocol == "" || a.Protocol == AccessLogProtocolHTTP { + if !serviceLog.GetAuthSuccess() { + a.Reason = "Authentication failed" + } else if serviceLog.GetResponseCode() >= 400 { + a.Reason = "Request failed" + } } } @@ -90,6 +105,12 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { cityName = &a.GeoLocation.CityName } + var protocol *string + if a.Protocol != "" { + p := string(a.Protocol) + protocol = &p + } + return &api.ProxyAccessLog{ Id: a.ID, ServiceId: a.ServiceID, @@ -107,5 +128,6 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { CityName: cityName, BytesUpload: a.BytesUpload, BytesDownload: a.BytesDownload, + Protocol: protocol, } } diff --git a/management/internals/modules/reverseproxy/domain/domain.go b/management/internals/modules/reverseproxy/domain/domain.go index 83fd669af..861d026a7 100644 --- a/management/internals/modules/reverseproxy/domain/domain.go +++ b/management/internals/modules/reverseproxy/domain/domain.go @@ -14,6 +14,9 @@ type Domain struct { TargetCluster string // The proxy cluster this domain should be validated against Type Type `gorm:"-"` Validated bool + // SupportsCustomPorts is populated at query time for free domains from the + // proxy cluster capabilities. Not persisted. + SupportsCustomPorts *bool `gorm:"-"` } // EventMeta returns activity event metadata for a domain diff --git a/management/internals/modules/reverseproxy/domain/manager/api.go b/management/internals/modules/reverseproxy/domain/manager/api.go index 2fbcdd5b8..d26a6a418 100644 --- a/management/internals/modules/reverseproxy/domain/manager/api.go +++ b/management/internals/modules/reverseproxy/domain/manager/api.go @@ -42,10 +42,11 @@ func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType { func domainToApi(d *domain.Domain) api.ReverseProxyDomain { resp := api.ReverseProxyDomain{ - Domain: d.Domain, - Id: d.ID, - Type: domainTypeToApi(d.Type), - Validated: d.Validated, + Domain: d.Domain, + Id: d.ID, + Type: domainTypeToApi(d.Type), + Validated: d.Validated, + SupportsCustomPorts: d.SupportsCustomPorts, } if d.TargetCluster != "" { resp.TargetCluster = &d.TargetCluster diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 8bbc98726..813027ea2 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -33,11 +33,16 @@ type proxyManager interface { GetActiveClusterAddresses(ctx context.Context) ([]string, error) } +type clusterCapabilities interface { + ClusterSupportsCustomPorts(clusterAddr string) *bool +} + type Manager struct { - store store - validator domain.Validator - proxyManager proxyManager - permissionsManager permissions.Manager + store store + validator domain.Validator + proxyManager proxyManager + clusterCapabilities clusterCapabilities + permissionsManager permissions.Manager accountManager account.Manager } @@ -51,6 +56,11 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio } } +// SetClusterCapabilities sets the cluster capabilities provider for domain queries. +func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) { + m.clusterCapabilities = caps +} + func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { @@ -80,24 +90,32 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d }).Debug("getting domains with proxy allow list") for _, cluster := range allowList { - ret = append(ret, &domain.Domain{ + d := &domain.Domain{ Domain: cluster, AccountID: accountID, Type: domain.TypeFree, Validated: true, - }) + } + if m.clusterCapabilities != nil { + d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster) + } + ret = append(ret, d) } // Add custom domains. for _, d := range domains { - ret = append(ret, &domain.Domain{ + cd := &domain.Domain{ ID: d.ID, Domain: d.Domain, AccountID: accountID, TargetCluster: d.TargetCluster, Type: domain.TypeCustom, Validated: d.Validated, - }) + } + if m.clusterCapabilities != nil && d.TargetCluster != "" { + cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster) + } + ret = append(ret, cd) } return ret, nil @@ -298,7 +316,7 @@ func extractClusterFromCustomDomains(domain string, customDomains []*domain.Doma // It matches the domain suffix against available clusters and returns the matching cluster. func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) { for _, cluster := range availableClusters { - if strings.HasSuffix(domain, "."+cluster) { + if domain == cluster || strings.HasSuffix(domain, "."+cluster) { return cluster, true } } diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 15f2f9f54..67a8e74fa 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -33,4 +33,5 @@ type Controller interface { RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error GetProxiesForCluster(clusterAddr string) []string + ClusterSupportsCustomPorts(clusterAddr string) *bool } diff --git a/management/internals/modules/reverseproxy/proxy/manager/controller.go b/management/internals/modules/reverseproxy/proxy/manager/controller.go index e5b3e9886..acb49c45b 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/controller.go +++ b/management/internals/modules/reverseproxy/proxy/manager/controller.go @@ -72,6 +72,11 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster return nil } +// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports. +func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool { + return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr) +} + // GetProxiesForCluster returns all proxy IDs registered for a specific cluster. func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string { proxySet, ok := c.clusterProxies.Load(clusterAddr) diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index d9645ba88..b07a21122 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -144,6 +144,20 @@ func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig)) } +// ClusterSupportsCustomPorts mocks base method. +func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts. +func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr) +} + // GetProxiesForCluster mocks base method. func (m *MockController) GetProxiesForCluster(clusterAddr string) []string { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/interface.go b/management/internals/modules/reverseproxy/service/interface.go index b420f22a8..39fd7e3ae 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -22,7 +22,7 @@ type Manager interface { GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) - RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error - StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error + RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error + StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error StartExposeReaper(ctx context.Context) } diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index 727b2c7de..bdc1f3e65 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -211,17 +211,17 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter } // RenewServiceFromPeer mocks base method. -func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { +func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain) + ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, serviceID) ret0, _ := ret[0].(error) return ret0 } // RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer. -func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, serviceID) } // SetCertificateIssuedAt mocks base method. @@ -265,17 +265,17 @@ func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Ca } // StopServiceFromPeer mocks base method. -func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { +func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain) + ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, serviceID) ret0, _ := ret[0].(error) return ret0 } // StopServiceFromPeer indicates an expected call of StopServiceFromPeer. -func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, serviceID) } // UpdateService mocks base method. diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index f28b633b8..c53219d2e 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -11,19 +11,22 @@ import ( domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) type handler struct { - manager rpservice.Manager + manager rpservice.Manager + permissionsManager permissions.Manager } // RegisterEndpoints registers all service HTTP endpoints. -func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { +func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) { h := &handler{ - manager: manager, + manager: manager, + permissionsManager: permissionsManager, } domainRouter := router.PathPrefix("/reverse-proxies").Subrouter() diff --git a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go index c831b4a22..6ff8343b9 100644 --- a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go @@ -18,8 +18,8 @@ func TestReapExpiredExposes(t *testing.T) { ctx := context.Background() resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) @@ -28,8 +28,8 @@ func TestReapExpiredExposes(t *testing.T) { // Create a non-expired service resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8081, - Protocol: "http", + Port: 8081, + Mode: "http", }) require.NoError(t, err) @@ -49,15 +49,16 @@ func TestReapAlreadyDeletedService(t *testing.T) { ctx := context.Background() resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) expireEphemeralService(t, testStore, testAccountID, resp.Domain) // Delete the service before reaping - err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) require.NoError(t, err) // Reaping should handle the already-deleted service gracefully @@ -70,8 +71,8 @@ func TestConcurrentReapAndRenew(t *testing.T) { for i := range 5 { _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080 + i, - Protocol: "http", + Port: uint16(8080 + i), + Mode: "http", }) require.NoError(t, err) } @@ -108,17 +109,19 @@ func TestRenewEphemeralService(t *testing.T) { t.Run("renew succeeds for active service", func(t *testing.T) { resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8082, - Protocol: "http", + Port: 8082, + Mode: "http", }) require.NoError(t, err) - err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svc, lookupErr := mgr.store.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, lookupErr) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID) require.NoError(t, err) }) t.Run("renew fails for nonexistent domain", func(t *testing.T) { - err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com") + err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id") require.Error(t, err) assert.Contains(t, err.Error(), "no active expose session") }) @@ -133,8 +136,8 @@ func TestCountAndExistsEphemeralServices(t *testing.T) { assert.Equal(t, int64(0), count) resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8083, - Protocol: "http", + Port: 8083, + Mode: "http", }) require.NoError(t, err) @@ -157,15 +160,15 @@ func TestMaxExposesPerPeerEnforced(t *testing.T) { for i := range maxExposesPerPeer { _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8090 + i, - Protocol: "http", + Port: uint16(8090 + i), + Mode: "http", }) require.NoError(t, err, "expose %d should succeed", i) } _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 9999, - Protocol: "http", + Port: 9999, + Mode: "http", }) require.Error(t, err) assert.Contains(t, err.Error(), "maximum number of active expose sessions") @@ -176,8 +179,8 @@ func TestReapSkipsRenewedService(t *testing.T) { ctx := context.Background() resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8086, - Protocol: "http", + Port: 8086, + Mode: "http", }) require.NoError(t, err) @@ -185,7 +188,9 @@ func TestReapSkipsRenewedService(t *testing.T) { expireEphemeralService(t, testStore, testAccountID, resp.Domain) // Renew it before the reaper runs - err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svc, err := testStore.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, err) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID) require.NoError(t, err) // Reaper should skip it because the re-check sees a fresh timestamp @@ -195,6 +200,14 @@ func TestReapSkipsRenewedService(t *testing.T) { require.NoError(t, err, "renewed service should survive reaping") } +// resolveServiceIDByDomain looks up a service ID by domain in tests. +func resolveServiceIDByDomain(t *testing.T, s store.Store, domain string) string { + t.Helper() + svc, err := s.GetServiceByDomain(context.Background(), domain) + require.NoError(t, err) + return svc.ID +} + // expireEphemeralService backdates meta_last_renewed_at to force expiration. func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) { t.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go new file mode 100644 index 000000000..c7a61ddcf --- /dev/null +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -0,0 +1,582 @@ +package manager + +import ( + "context" + "net" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/mock_server" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +const testCluster = "test-cluster" + +func boolPtr(v bool) *bool { return &v } + +// setupL4Test creates a manager with a mock proxy controller for L4 port tests. +func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Store, *proxy.MockController) { + t.Helper() + + ctrl := gomock.NewController(t) + + ctx := context.Background() + testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + err = testStore.SaveAccount(ctx, &types.Account{ + Id: testAccountID, + CreatedBy: testUserID, + Settings: &types.Settings{ + PeerExposeEnabled: true, + PeerExposeGroups: []string{testGroupID}, + }, + Users: map[string]*types.User{ + testUserID: { + Id: testUserID, + AccountID: testAccountID, + Role: types.UserRoleAdmin, + }, + }, + Peers: map[string]*nbpeer.Peer{ + testPeerID: { + ID: testPeerID, + AccountID: testAccountID, + Key: "test-key", + DNSLabel: "test-peer", + Name: "test-peer", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, + }, + }, + Groups: map[string]*types.Group{ + testGroupID: { + ID: testGroupID, + AccountID: testAccountID, + Name: "Expose Group", + }, + }, + }) + require.NoError(t, err) + + err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID) + require.NoError(t, err) + + mockCtrl := proxy.NewMockController(ctrl) + mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes() + mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes() + + accountMgr := &mock_server.MockAccountManager{ + StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { + return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID) + }, + } + + mgr := &Manager{ + store: testStore, + accountManager: accountMgr, + permissionsManager: permissions.NewManager(testStore), + proxyController: mockCtrl, + clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}}, + } + mgr.exposeReaper = &exposeReaper{manager: mgr} + + return mgr, testStore, mockCtrl +} + +// seedService creates a service directly in the store for test setup. +func seedService(t *testing.T, s store.Store, name, protocol, domain, cluster string, port uint16) *rpservice.Service { + t.Helper() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: name, + Mode: protocol, + Domain: domain, + ProxyCluster: cluster, + ListenPort: port, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: protocol, Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + err := s.CreateService(context.Background(), svc) + require.NoError(t, err) + return svc +} + +func TestPortConflict_TCPSamePortCluster(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tcp", "tcp", testCluster, testCluster, 5432) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "conflicting-tcp", + Mode: "tcp", + Domain: "conflicting-tcp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 5432, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TCP+TCP on same port/cluster should be rejected") + assert.Contains(t, err.Error(), "already in use") +} + +func TestPortConflict_UDPSamePortCluster(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-udp", "udp", testCluster, testCluster, 5432) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "conflicting-udp", + Mode: "udp", + Domain: "conflicting-udp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 5432, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "udp", Port: 9090, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "UDP+UDP on same port/cluster should be rejected") + assert.Contains(t, err.Error(), "already in use") +} + +func TestPortConflict_TLSSamePortDifferentDomain(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app1.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "new-tls", + Mode: "tls", + Domain: "app2.example.com", + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS+TLS on same port with different domains should be allowed (SNI routing)") +} + +func TestPortConflict_TLSSamePortSameDomain(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "duplicate-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TLS+TLS on same domain should be rejected") + assert.Contains(t, err.Error(), "domain already taken") +} + +func TestPortConflict_TLSAndTCPSamePort(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "new-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS+TCP on same port should be allowed (multiplexed)") +} + +func TestAutoAssign_TCPNoListenPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "auto-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax, + "auto-assigned port %d should be in range [%d, %d]", svc.ListenPort, autoAssignPortMin, autoAssignPortMax) + assert.True(t, svc.PortAutoAssigned, "PortAutoAssigned should be set") +} + +func TestAutoAssign_TCPCustomPortRejectedWhenNotSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TCP with custom port should be rejected when cluster doesn't support it") + assert.Contains(t, err.Error(), "custom ports") +} + +func TestAutoAssign_TLSCustomPortAlwaysAllowed(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 9999, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS with custom port should always be allowed regardless of cluster capability") + assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden") + assert.False(t, svc.PortAutoAssigned, "PortAutoAssigned should not be set for TLS") +} + +func TestAutoAssign_EphemeralOverridesPortWhenNotSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "ephemeral-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "ephemeral", + SourcePeer: testPeerID, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc) + require.NoError(t, err) + assert.NotEqual(t, uint16(5555), svc.ListenPort, "requested port should be overridden") + assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax, + "auto-assigned port %d should be in range", svc.ListenPort) + assert.True(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_EphemeralTLSKeepsCustomPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "ephemeral-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 9999, + Enabled: true, + Source: "ephemeral", + SourcePeer: testPeerID, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc) + require.NoError(t, err) + assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden") + assert.False(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_AvoidsExistingPorts(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existingPort := uint16(20000) + seedService(t, testStore, "existing", "tcp", testCluster, testCluster, existingPort) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "auto-tcp", + Mode: "tcp", + Domain: "auto-tcp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.NotEqual(t, existingPort, svc.ListenPort, "auto-assigned port should not collide with existing") + assert.True(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_TCPCustomPortAllowedWhenSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.Equal(t, uint16(5555), svc.ListenPort, "custom port should be preserved when supported") + assert.False(t, svc.PortAutoAssigned) +} + +func TestUpdate_PreservesExistingListenPort(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345) + + updated := &rpservice.Service{ + ID: existing.ID, + AccountID: testAccountID, + Name: "tcp-svc-renamed", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + + _, err := mgr.persistServiceUpdate(ctx, testAccountID, updated) + require.NoError(t, err) + assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when update sends 0") +} + +func TestUpdate_AllowsPortChange(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345) + + updated := &rpservice.Service{ + ID: existing.ID, + AccountID: testAccountID, + Name: "tcp-svc", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 54321, + Enabled: true, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + + _, err := mgr.persistServiceUpdate(ctx, testAccountID, updated) + require.NoError(t, err) + assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied") +} + +func TestCreateServiceFromPeer_TCP(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + }) + require.NoError(t, err) + + assert.NotEmpty(t, resp.ServiceName) + assert.Contains(t, resp.Domain, ".test.netbird.io", "TCP uses unique subdomain") + assert.True(t, resp.PortAutoAssigned, "port should be auto-assigned when cluster doesn't support custom ports") + assert.Contains(t, resp.ServiceURL, "tcp://") +} + +func TestCreateServiceFromPeer_TCP_CustomPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + ListenPort: 15432, + }) + require.NoError(t, err) + + assert.False(t, resp.PortAutoAssigned) + assert.Contains(t, resp.ServiceURL, ":15432") +} + +func TestCreateServiceFromPeer_TCP_DefaultListenPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + }) + require.NoError(t, err) + + // When no explicit listen port, defaults to target port + assert.Contains(t, resp.ServiceURL, ":5432") + assert.False(t, resp.PortAutoAssigned) +} + +func TestCreateServiceFromPeer_TLS(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 443, + Mode: "tls", + }) + require.NoError(t, err) + + assert.Contains(t, resp.Domain, ".test.netbird.io", "TLS uses subdomain") + assert.Contains(t, resp.ServiceURL, "tls://") + assert.Contains(t, resp.ServiceURL, ":443") + // TLS always keeps its port (not port-based protocol for auto-assign) + assert.False(t, resp.PortAutoAssigned) +} + +func TestCreateServiceFromPeer_TCP_StopAndRenew(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "tcp", + }) + require.NoError(t, err) + + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.NoError(t, err) + + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.NoError(t, err) + + // Renew after stop should fail + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.Error(t, err) +} + +func TestCreateServiceFromPeer_L4_RejectsAuth(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "tcp", + Pin: "123456", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "authentication is not supported") +} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index cae3d3bda..c40961fdc 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "math/rand/v2" + "os" "slices" + "strconv" "time" log "github.com/sirupsen/logrus" @@ -23,6 +25,45 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +const ( + defaultAutoAssignPortMin uint16 = 10000 + defaultAutoAssignPortMax uint16 = 49151 + + // EnvAutoAssignPortMin overrides the lower bound for auto-assigned L4 listen ports. + EnvAutoAssignPortMin = "NB_PROXY_PORT_MIN" + // EnvAutoAssignPortMax overrides the upper bound for auto-assigned L4 listen ports. + EnvAutoAssignPortMax = "NB_PROXY_PORT_MAX" +) + +var ( + autoAssignPortMin = defaultAutoAssignPortMin + autoAssignPortMax = defaultAutoAssignPortMax +) + +func init() { + autoAssignPortMin = portFromEnv(EnvAutoAssignPortMin, defaultAutoAssignPortMin) + autoAssignPortMax = portFromEnv(EnvAutoAssignPortMax, defaultAutoAssignPortMax) + if autoAssignPortMin > autoAssignPortMax { + log.Warnf("port range invalid: %s (%d) > %s (%d), using defaults", + EnvAutoAssignPortMin, autoAssignPortMin, EnvAutoAssignPortMax, autoAssignPortMax) + autoAssignPortMin = defaultAutoAssignPortMin + autoAssignPortMax = defaultAutoAssignPortMax + } +} + +func portFromEnv(key string, fallback uint16) uint16 { + val := os.Getenv(key) + if val == "" { + return fallback + } + n, err := strconv.ParseUint(val, 10, 16) + if err != nil { + log.Warnf("invalid %s value %q, using default %d: %v", key, val, fallback, err) + return fallback + } + return uint16(n) +} + const unknownHostPlaceholder = "unknown" // ClusterDeriver derives the proxy cluster from a domain. @@ -115,6 +156,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s * return fmt.Errorf("unknown target type: %s", target.TargetType) } } + return nil } @@ -197,55 +239,19 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri return nil } -func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error { +func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil { + if svc.Domain != "" { + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + return err + } + } + + if err := m.ensureL4Port(ctx, transaction, svc); err != nil { return err } - if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { - return err - } - - if err := transaction.CreateService(ctx, service); err != nil { - return fmt.Errorf("failed to create service: %w", err) - } - - return nil - }) -} - -// persistNewEphemeralService creates an ephemeral service inside a single transaction -// that also enforces the duplicate and per-peer limit checks atomically. -// The count and exists queries use FOR UPDATE locking to serialize concurrent creates -// for the same peer, preventing the per-peer limit from being bypassed. -func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { - return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - // Lock the peer row to serialize concurrent creates for the same peer. - // Without this, when no ephemeral rows exist yet, FOR UPDATE on the services - // table returns no rows and acquires no locks, allowing concurrent inserts - // to bypass the per-peer limit. - if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil { - return fmt.Errorf("lock peer row: %w", err) - } - - exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain) - if err != nil { - return fmt.Errorf("check existing expose: %w", err) - } - if exists { - return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") - } - - count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID) - if err != nil { - return fmt.Errorf("count peer exposes: %w", err) - } - if count >= int64(maxExposesPerPeer) { - return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) - } - - if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + if err := m.checkPortConflict(ctx, transaction, svc); err != nil { return err } @@ -261,11 +267,155 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee }) } +// ensureL4Port auto-assigns a listen port when needed and validates cluster support. +func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error { + if !service.IsL4Protocol(svc.Mode) { + return nil + } + customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster) + if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) { + if svc.Source != service.SourceEphemeral { + return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster) + } + svc.ListenPort = 0 + } + if svc.ListenPort == 0 { + port, err := m.assignPort(ctx, tx, svc.ProxyCluster) + if err != nil { + return err + } + svc.ListenPort = port + svc.PortAutoAssigned = true + } + return nil +} + +// checkPortConflict rejects L4 services that would conflict on the same listener. +// For TCP/UDP: unique per cluster+protocol+port. +// For TLS: unique per cluster+port+domain (SNI routing allows sharing ports). +// Cross-protocol conflicts (TLS vs raw TCP) are intentionally not checked: +// the proxy router multiplexes TLS (via SNI) and raw TCP (via fallback) on the same listener. +func (m *Manager) checkPortConflict(ctx context.Context, transaction store.Store, svc *service.Service) error { + if !service.IsL4Protocol(svc.Mode) || svc.ListenPort == 0 { + return nil + } + + existing, err := transaction.GetServicesByClusterAndPort(ctx, store.LockingStrengthUpdate, svc.ProxyCluster, svc.Mode, svc.ListenPort) + if err != nil { + return fmt.Errorf("query port conflicts: %w", err) + } + for _, s := range existing { + if s.ID == svc.ID { + continue + } + // TLS services on the same port are allowed if they have different domains (SNI routing) + if svc.Mode == service.ModeTLS && s.Domain != svc.Domain { + continue + } + return status.Errorf(status.AlreadyExists, + "%s port %d is already in use by service %q on cluster %s", + svc.Mode, svc.ListenPort, s.Name, svc.ProxyCluster) + } + + return nil +} + +// assignPort picks a random available port on the cluster within the auto-assign range. +func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string) (uint16, error) { + services, err := tx.GetServicesByCluster(ctx, store.LockingStrengthUpdate, cluster) + if err != nil { + return 0, fmt.Errorf("query cluster ports: %w", err) + } + + occupied := make(map[uint16]struct{}, len(services)) + for _, s := range services { + if s.ListenPort > 0 { + occupied[s.ListenPort] = struct{}{} + } + } + + portRange := int(autoAssignPortMax-autoAssignPortMin) + 1 + for range 100 { + port := autoAssignPortMin + uint16(rand.IntN(portRange)) + if _, taken := occupied[port]; !taken { + return port, nil + } + } + + for port := autoAssignPortMin; port <= autoAssignPortMax; port++ { + if _, taken := occupied[port]; !taken { + return port, nil + } + } + + return 0, status.Errorf(status.PreconditionFailed, "no available ports on cluster %s", cluster) +} + +// persistNewEphemeralService creates an ephemeral service inside a single transaction +// that also enforces the duplicate and per-peer limit checks atomically. +// The count and exists queries use FOR UPDATE locking to serialize concurrent creates +// for the same peer, preventing the per-peer limit from being bypassed. +func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil { + return err + } + + if err := m.ensureL4Port(ctx, transaction, svc); err != nil { + return err + } + + if err := m.checkPortConflict(ctx, transaction, svc); err != nil { + return err + } + + if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { + return err + } + + if err := transaction.CreateService(ctx, svc); err != nil { + return fmt.Errorf("create service: %w", err) + } + + return nil + }) +} + +func (m *Manager) validateEphemeralPreconditions(ctx context.Context, transaction store.Store, accountID, peerID string, svc *service.Service) error { + // Lock the peer row to serialize concurrent creates for the same peer. + if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil { + return fmt.Errorf("lock peer row: %w", err) + } + + exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain) + if err != nil { + return fmt.Errorf("check existing expose: %w", err) + } + if exists { + return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") + } + + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + return err + } + + count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID) + if err != nil { + return fmt.Errorf("count peer exposes: %w", err) + } + if count >= int64(maxExposesPerPeer) { + return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) + } + + return nil +} + +// checkDomainAvailable checks that no other service already uses this domain. func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error { existingService, err := transaction.GetServiceByDomain(ctx, domain) if err != nil { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { - return fmt.Errorf("failed to check existing service: %w", err) + return fmt.Errorf("check existing service: %w", err) } return nil } @@ -322,6 +472,10 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se return err } + if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil { + return err + } + updateInfo.oldCluster = existingService.ProxyCluster updateInfo.domainChanged = existingService.Domain != service.Domain @@ -335,12 +489,18 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se m.preserveExistingAuthSecrets(service, existingService) m.preserveServiceMetadata(service, existingService) + m.preserveListenPort(service, existingService) updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled + if err := m.ensureL4Port(ctx, transaction, service); err != nil { + return err + } + if err := m.checkPortConflict(ctx, transaction, service); err != nil { + return err + } if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { return err } - if err := transaction.UpdateService(ctx, service); err != nil { return fmt.Errorf("update service: %w", err) } @@ -351,23 +511,39 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se return &updateInfo, err } -func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error { - if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil { +func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error { + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil { return err } if m.clusterDeriver != nil { - newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) + newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain) if err != nil { - log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain) + log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain) } else { - service.ProxyCluster = newCluster + svc.ProxyCluster = newCluster } } return nil } +// validateProtocolChange rejects mode changes on update. +// Only empty<->HTTP is allowed; all other transitions are rejected. +func validateProtocolChange(oldMode, newMode string) error { + if newMode == "" || newMode == oldMode { + return nil + } + if isHTTPFamily(oldMode) && isHTTPFamily(newMode) { + return nil + } + return status.Errorf(status.InvalidArgument, "cannot change mode from %q to %q", oldMode, newMode) +} + +func isHTTPFamily(mode string) bool { + return mode == "" || mode == "http" +} + func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) { if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && @@ -388,6 +564,13 @@ func (m *Manager) preserveServiceMetadata(service, existingService *service.Serv service.SessionPublicKey = existingService.SessionPublicKey } +func (m *Manager) preserveListenPort(svc, existing *service.Service) { + if existing.ListenPort > 0 && svc.ListenPort == 0 { + svc.ListenPort = existing.ListenPort + svc.PortAutoAssigned = existing.PortAutoAssigned + } +} + func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) { oidcCfg := m.proxyController.GetOIDCValidationConfig() @@ -675,6 +858,10 @@ func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerI return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group") } +func (m *Manager) resolveDefaultDomain(serviceName string) (string, error) { + return m.buildRandomDomain(serviceName) +} + // CreateServiceFromPeer creates a service initiated by a peer expose request. // It validates the request, checks expose permissions, enforces the per-peer limit, // creates the service, and tracks it for TTL-based reaping. @@ -696,9 +883,9 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s svc.Source = service.SourceEphemeral if svc.Domain == "" { - domain, err := m.buildRandomDomain(svc.Name) + domain, err := m.resolveDefaultDomain(svc.Name) if err != nil { - return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err) + return nil, err } svc.Domain = domain } @@ -739,10 +926,16 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) + serviceURL := "https://" + svc.Domain + if service.IsL4Protocol(svc.Mode) { + serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort) + } + return &service.ExposeServiceResponse{ - ServiceName: svc.Name, - ServiceURL: "https://" + svc.Domain, - Domain: svc.Domain, + ServiceName: svc.Name, + ServiceURL: serviceURL, + Domain: svc.Domain, + PortAutoAssigned: svc.PortAutoAssigned, }, nil } @@ -761,64 +954,47 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr return groupIDs, nil } -func (m *Manager) buildRandomDomain(name string) (string, error) { +func (m *Manager) getDefaultClusterDomain() (string, error) { if m.clusterDeriver == nil { - return "", fmt.Errorf("unable to get random domain") + return "", fmt.Errorf("unable to get cluster domain") } clusterDomains := m.clusterDeriver.GetClusterDomains() if len(clusterDomains) == 0 { - return "", fmt.Errorf("no cluster domains found for service %s", name) + return "", fmt.Errorf("no cluster domains available") } - index := rand.IntN(len(clusterDomains)) - domain := name + "." + clusterDomains[index] - return domain, nil + return clusterDomains[rand.IntN(len(clusterDomains))], nil +} + +func (m *Manager) buildRandomDomain(name string) (string, error) { + domain, err := m.getDefaultClusterDomain() + if err != nil { + return "", err + } + return name + "." + domain, nil } // RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service. -func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { - return m.store.RenewEphemeralService(ctx, accountID, peerID, domain) +func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { + return m.store.RenewEphemeralService(ctx, accountID, peerID, serviceID) } // StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB. -func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { - if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil { - log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err) +func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { + if err := m.deleteServiceFromPeer(ctx, accountID, peerID, serviceID, false); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer-exposed service %s: %v", serviceID, err) return err } return nil } -// deleteServiceFromPeer deletes a peer-initiated service identified by domain. +// deleteServiceFromPeer deletes a peer-initiated service identified by service ID. // When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed. -func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error { - svc, err := m.lookupPeerService(ctx, accountID, peerID, domain) - if err != nil { - return err - } - +func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string, expired bool) error { activityCode := activity.PeerServiceUnexposed if expired { activityCode = activity.PeerServiceExposeExpired } - return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode) -} - -// lookupPeerService finds a peer-initiated service by domain and validates ownership. -func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) { - svc, err := m.store.GetServiceByDomain(ctx, domain) - if err != nil { - return nil, err - } - - if svc.Source != service.SourceEphemeral { - return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose") - } - - if svc.SourcePeer != peerID { - return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer") - } - - return svc, nil + return m.deletePeerService(ctx, accountID, peerID, serviceID, activityCode) } func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error { diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index ba4e1c805..d23c91017 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -803,8 +803,8 @@ func TestCreateServiceFromPeer(t *testing.T) { mgr, testStore := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -826,9 +826,9 @@ func TestCreateServiceFromPeer(t *testing.T) { mgr, _ := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 80, - Protocol: "http", - Domain: "example.com", + Port: 80, + Mode: "http", + Domain: "example.com", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -847,8 +847,8 @@ func TestCreateServiceFromPeer(t *testing.T) { require.NoError(t, err) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -860,8 +860,8 @@ func TestCreateServiceFromPeer(t *testing.T) { mgr, _ := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 0, - Protocol: "http", + Port: 0, + Mode: "http", } _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -878,62 +878,52 @@ func TestExposeServiceRequestValidate(t *testing.T) { }{ { name: "valid http request", - req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"}, + req: rpservice.ExposeServiceRequest{Port: 8080, Mode: "http"}, wantErr: "", }, { - name: "valid https request with pin", - req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"}, - wantErr: "", + name: "https mode rejected", + req: rpservice.ExposeServiceRequest{Port: 443, Mode: "https", Pin: "123456"}, + wantErr: "unsupported mode", }, { name: "port zero rejected", - req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"}, + req: rpservice.ExposeServiceRequest{Port: 0, Mode: "http"}, wantErr: "port must be between 1 and 65535", }, { - name: "negative port rejected", - req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"}, - wantErr: "port must be between 1 and 65535", - }, - { - name: "port above 65535 rejected", - req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"}, - wantErr: "port must be between 1 and 65535", - }, - { - name: "unsupported protocol", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"}, - wantErr: "unsupported protocol", + name: "unsupported mode", + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "ftp"}, + wantErr: "unsupported mode", }, { name: "invalid pin format", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "abc"}, wantErr: "invalid pin", }, { name: "pin too short", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "12345"}, wantErr: "invalid pin", }, { name: "valid 6-digit pin", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "000000"}, wantErr: "", }, { name: "empty user group name", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", UserGroups: []string{"valid", ""}}, wantErr: "user group name cannot be empty", }, { name: "invalid name prefix", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "INVALID"}, wantErr: "invalid name prefix", }, { name: "valid name prefix", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "my-service"}, wantErr: "", }, } @@ -966,14 +956,14 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { // First create a service req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - // Delete by domain using unexported method - err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, false) require.NoError(t, err) // Verify service is deleted @@ -982,16 +972,17 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { }) t.Run("expire uses correct activity", func(t *testing.T) { - mgr, _ := setupIntegrationTest(t) + mgr, testStore := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, true) require.NoError(t, err) }) } @@ -1003,13 +994,14 @@ func TestStopServiceFromPeer(t *testing.T) { mgr, testStore := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) require.NoError(t, err) _, err = testStore.GetServiceByDomain(ctx, resp.Domain) @@ -1022,8 +1014,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) { mgr, testStore := setupIntegrationTest(t) resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) @@ -1042,8 +1034,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) { assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete") _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 9090, - Protocol: "http", + Port: 9090, + Mode: "http", }) assert.NoError(t, err, "new expose should succeed after API delete") } @@ -1054,8 +1046,8 @@ func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) { for i := range 3 { _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080 + i, - Protocol: "http", + Port: uint16(8080 + i), + Mode: "http", }) require.NoError(t, err) } @@ -1076,21 +1068,22 @@ func TestRenewServiceFromPeer(t *testing.T) { ctx := context.Background() t.Run("renews tracked expose", func(t *testing.T) { - mgr, _ := setupIntegrationTest(t) + mgr, testStore := setupIntegrationTest(t) resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) - err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) require.NoError(t, err) }) t.Run("fails for untracked domain", func(t *testing.T) { mgr, _ := setupIntegrationTest(t) - err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com") + err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id") require.Error(t, err) }) } @@ -1191,3 +1184,33 @@ func TestDeleteService_DeletesTargets(t *testing.T) { require.NoError(t, err) assert.Len(t, targets, 0, "All targets should be deleted when service is deleted") } + +func TestValidateProtocolChange(t *testing.T) { + tests := []struct { + name string + oldP string + newP string + wantErr bool + }{ + {"empty to http", "", "http", false}, + {"http to http", "http", "http", false}, + {"same protocol", "tcp", "tcp", false}, + {"empty new proto", "tcp", "", false}, + {"http to tcp", "http", "tcp", true}, + {"tcp to udp", "tcp", "udp", true}, + {"tls to http", "tls", "http", true}, + {"udp to tls", "udp", "tls", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateProtocolChange(tt.oldP, tt.newP) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot change mode") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index bfad7fe9a..623284404 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -34,6 +34,7 @@ const ( ) type Status string +type TargetType string const ( StatusPending Status = "pending" @@ -43,34 +44,36 @@ const ( StatusCertificateFailed Status = "certificate_failed" StatusError Status = "error" - TargetTypePeer = "peer" - TargetTypeHost = "host" - TargetTypeDomain = "domain" - TargetTypeSubnet = "subnet" + TargetTypePeer TargetType = "peer" + TargetTypeHost TargetType = "host" + TargetTypeDomain TargetType = "domain" + TargetTypeSubnet TargetType = "subnet" SourcePermanent = "permanent" SourceEphemeral = "ephemeral" ) type TargetOptions struct { - SkipTLSVerify bool `json:"skip_tls_verify"` - RequestTimeout time.Duration `json:"request_timeout,omitempty"` - PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"` - CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"` + SkipTLSVerify bool `json:"skip_tls_verify"` + RequestTimeout time.Duration `json:"request_timeout,omitempty"` + SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"` + PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"` + CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"` } type Target struct { - ID uint `gorm:"primaryKey" json:"-"` - AccountID string `gorm:"index:idx_target_account;not null" json:"-"` - ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"` - Path *string `json:"path,omitempty"` - Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored - Port int `gorm:"index:idx_target_port" json:"port"` - Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` - TargetId string `gorm:"index:idx_target_id" json:"target_id"` - TargetType string `gorm:"index:idx_target_type" json:"target_type"` - Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"` - Options TargetOptions `gorm:"embedded" json:"options"` + ID uint `gorm:"primaryKey" json:"-"` + AccountID string `gorm:"index:idx_target_account;not null" json:"-"` + ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"` + Path *string `json:"path,omitempty"` + Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored + Port uint16 `gorm:"index:idx_target_port" json:"port"` + Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` + TargetId string `gorm:"index:idx_target_id" json:"target_id"` + TargetType TargetType `gorm:"index:idx_target_type" json:"target_type"` + Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"` + Options TargetOptions `gorm:"embedded" json:"options"` + ProxyProtocol bool `json:"proxy_protocol"` } type PasswordAuthConfig struct { @@ -146,23 +149,10 @@ type Service struct { SessionPublicKey string `gorm:"column:session_public_key"` Source string `gorm:"default:'permanent';index:idx_service_source_peer"` SourcePeer string `gorm:"index:idx_service_source_peer"` -} - -func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service { - for _, target := range targets { - target.AccountID = accountID - } - - s := &Service{ - AccountID: accountID, - Name: name, - Domain: domain, - ProxyCluster: proxyCluster, - Targets: targets, - Enabled: enabled, - } - s.InitNewRecord() - return s + // Mode determines the service type: "http", "tcp", "udp", or "tls". + Mode string `gorm:"default:'http'"` + ListenPort uint16 + PortAutoAssigned bool } // InitNewRecord generates a new unique ID and resets metadata for a newly created @@ -177,21 +167,17 @@ func (s *Service) InitNewRecord() { } func (s *Service) ToAPIResponse() *api.Service { - s.Auth.ClearSecrets() - authConfig := api.ServiceAuthConfig{} if s.Auth.PasswordAuth != nil { authConfig.PasswordAuth = &api.PasswordAuthConfig{ - Enabled: s.Auth.PasswordAuth.Enabled, - Password: s.Auth.PasswordAuth.Password, + Enabled: s.Auth.PasswordAuth.Enabled, } } if s.Auth.PinAuth != nil { authConfig.PinAuth = &api.PINAuthConfig{ Enabled: s.Auth.PinAuth.Enabled, - Pin: s.Auth.PinAuth.Pin, } } @@ -208,13 +194,18 @@ func (s *Service) ToAPIResponse() *api.Service { st := api.ServiceTarget{ Path: target.Path, Host: &target.Host, - Port: target.Port, + Port: int(target.Port), Protocol: api.ServiceTargetProtocol(target.Protocol), TargetId: target.TargetId, TargetType: api.ServiceTargetTargetType(target.TargetType), Enabled: target.Enabled, } - st.Options = targetOptionsToAPI(target.Options) + opts := targetOptionsToAPI(target.Options) + if opts == nil { + opts = &api.ServiceTargetOptions{} + } + opts.ProxyProtocol = &target.ProxyProtocol + st.Options = opts apiTargets = append(apiTargets, st) } @@ -227,6 +218,9 @@ func (s *Service) ToAPIResponse() *api.Service { meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt } + mode := api.ServiceMode(s.Mode) + listenPort := int(s.ListenPort) + resp := &api.Service{ Id: s.ID, Name: s.Name, @@ -237,6 +231,9 @@ func (s *Service) ToAPIResponse() *api.Service { RewriteRedirects: &s.RewriteRedirects, Auth: authConfig, Meta: meta, + Mode: &mode, + ListenPort: &listenPort, + PortAutoAssigned: &s.PortAutoAssigned, } if s.ProxyCluster != "" { @@ -247,37 +244,7 @@ func (s *Service) ToAPIResponse() *api.Service { } func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping { - pathMappings := make([]*proto.PathMapping, 0, len(s.Targets)) - for _, target := range s.Targets { - if !target.Enabled { - continue - } - - // TODO: Make path prefix stripping configurable per-target. - // Currently the matching prefix is baked into the target URL path, - // so the proxy strips-then-re-adds it (effectively a no-op). - targetURL := url.URL{ - Scheme: target.Protocol, - Host: target.Host, - Path: "/", // TODO: support service path - } - if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { - targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port)) - } - - path := "/" - if target.Path != nil { - path = *target.Path - } - - pm := &proto.PathMapping{ - Path: path, - Target: targetURL.String(), - } - - pm.Options = targetOptionsToProto(target.Options) - pathMappings = append(pathMappings, pm) - } + pathMappings := s.buildPathMappings() auth := &proto.Authentication{ SessionKey: s.SessionPublicKey, @@ -306,9 +273,58 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf AccountId: s.AccountID, PassHostHeader: s.PassHostHeader, RewriteRedirects: s.RewriteRedirects, + Mode: s.Mode, + ListenPort: int32(s.ListenPort), //nolint:gosec } } +// buildPathMappings constructs PathMapping entries from targets. +// For HTTP/HTTPS, each target becomes a path-based route with a full URL. +// For L4/TLS, a single target maps to a host:port address. +func (s *Service) buildPathMappings() []*proto.PathMapping { + pathMappings := make([]*proto.PathMapping, 0, len(s.Targets)) + for _, target := range s.Targets { + if !target.Enabled { + continue + } + + if IsL4Protocol(s.Mode) { + pm := &proto.PathMapping{ + Target: net.JoinHostPort(target.Host, strconv.FormatUint(uint64(target.Port), 10)), + } + opts := l4TargetOptionsToProto(target) + if opts != nil { + pm.Options = opts + } + pathMappings = append(pathMappings, pm) + continue + } + + // HTTP/HTTPS: build full URL + targetURL := url.URL{ + Scheme: target.Protocol, + Host: target.Host, + Path: "/", + } + if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { + targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10)) + } + + path := "/" + if target.Path != nil { + path = *target.Path + } + + pm := &proto.PathMapping{ + Path: path, + Target: targetURL.String(), + } + pm.Options = targetOptionsToProto(target.Options) + pathMappings = append(pathMappings, pm) + } + return pathMappings +} + func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { switch op { case Create: @@ -325,8 +341,8 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { // isDefaultPort reports whether port is the standard default for the given scheme // (443 for https, 80 for http). -func isDefaultPort(scheme string, port int) bool { - return (scheme == "https" && port == 443) || (scheme == "http" && port == 80) +func isDefaultPort(scheme string, port uint16) bool { + return (scheme == TargetProtoHTTPS && port == 443) || (scheme == TargetProtoHTTP && port == 80) } // PathRewriteMode controls how the request path is rewritten before forwarding. @@ -346,7 +362,7 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode { } func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { - if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 { + if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 { return nil } apiOpts := &api.ServiceTargetOptions{} @@ -357,6 +373,10 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { s := opts.RequestTimeout.String() apiOpts.RequestTimeout = &s } + if opts.SessionIdleTimeout != 0 { + s := opts.SessionIdleTimeout.String() + apiOpts.SessionIdleTimeout = &s + } if opts.PathRewrite != "" { pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite) apiOpts.PathRewrite = &pr @@ -382,6 +402,23 @@ func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions { return popts } +// l4TargetOptionsToProto converts L4-relevant target options to proto. +func l4TargetOptionsToProto(target *Target) *proto.PathTargetOptions { + if !target.ProxyProtocol && target.Options.RequestTimeout == 0 && target.Options.SessionIdleTimeout == 0 { + return nil + } + opts := &proto.PathTargetOptions{ + ProxyProtocol: target.ProxyProtocol, + } + if target.Options.RequestTimeout > 0 { + opts.RequestTimeout = durationpb.New(target.Options.RequestTimeout) + } + if target.Options.SessionIdleTimeout > 0 { + opts.SessionIdleTimeout = durationpb.New(target.Options.SessionIdleTimeout) + } + return opts +} + func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) { var opts TargetOptions if o.SkipTlsVerify != nil { @@ -394,6 +431,13 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, } opts.RequestTimeout = d } + if o.SessionIdleTimeout != nil { + d, err := time.ParseDuration(*o.SessionIdleTimeout) + if err != nil { + return opts, fmt.Errorf("target %d: parse session_idle_timeout %q: %w", idx, *o.SessionIdleTimeout, err) + } + opts.SessionIdleTimeout = d + } if o.PathRewrite != nil { opts.PathRewrite = PathRewriteMode(*o.PathRewrite) } @@ -408,15 +452,49 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro s.Domain = req.Domain s.AccountID = accountID - targets := make([]*Target, 0, len(req.Targets)) - for i, apiTarget := range req.Targets { + if req.Mode != nil { + s.Mode = string(*req.Mode) + } + if req.ListenPort != nil { + s.ListenPort = uint16(*req.ListenPort) //nolint:gosec + } + + targets, err := targetsFromAPI(accountID, req.Targets) + if err != nil { + return err + } + s.Targets = targets + s.Enabled = req.Enabled + + if req.PassHostHeader != nil { + s.PassHostHeader = *req.PassHostHeader + } + if req.RewriteRedirects != nil { + s.RewriteRedirects = *req.RewriteRedirects + } + + if req.Auth != nil { + s.Auth = authFromAPI(req.Auth) + } + + return nil +} + +func targetsFromAPI(accountID string, apiTargetsPtr *[]api.ServiceTarget) ([]*Target, error) { + var apiTargets []api.ServiceTarget + if apiTargetsPtr != nil { + apiTargets = *apiTargetsPtr + } + + targets := make([]*Target, 0, len(apiTargets)) + for i, apiTarget := range apiTargets { target := &Target{ AccountID: accountID, Path: apiTarget.Path, - Port: apiTarget.Port, + Port: uint16(apiTarget.Port), //nolint:gosec // validated by API layer Protocol: string(apiTarget.Protocol), TargetId: apiTarget.TargetId, - TargetType: string(apiTarget.TargetType), + TargetType: TargetType(apiTarget.TargetType), Enabled: apiTarget.Enabled, } if apiTarget.Host != nil { @@ -425,49 +503,42 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro if apiTarget.Options != nil { opts, err := targetOptionsFromAPI(i, apiTarget.Options) if err != nil { - return err + return nil, err } target.Options = opts + if apiTarget.Options.ProxyProtocol != nil { + target.ProxyProtocol = *apiTarget.Options.ProxyProtocol + } } targets = append(targets, target) } - s.Targets = targets + return targets, nil +} - s.Enabled = req.Enabled - - if req.PassHostHeader != nil { - s.PassHostHeader = *req.PassHostHeader - } - - if req.RewriteRedirects != nil { - s.RewriteRedirects = *req.RewriteRedirects - } - - if req.Auth.PasswordAuth != nil { - s.Auth.PasswordAuth = &PasswordAuthConfig{ - Enabled: req.Auth.PasswordAuth.Enabled, - Password: req.Auth.PasswordAuth.Password, +func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig { + var auth AuthConfig + if reqAuth.PasswordAuth != nil { + auth.PasswordAuth = &PasswordAuthConfig{ + Enabled: reqAuth.PasswordAuth.Enabled, + Password: reqAuth.PasswordAuth.Password, } } - - if req.Auth.PinAuth != nil { - s.Auth.PinAuth = &PINAuthConfig{ - Enabled: req.Auth.PinAuth.Enabled, - Pin: req.Auth.PinAuth.Pin, + if reqAuth.PinAuth != nil { + auth.PinAuth = &PINAuthConfig{ + Enabled: reqAuth.PinAuth.Enabled, + Pin: reqAuth.PinAuth.Pin, } } - - if req.Auth.BearerAuth != nil { + if reqAuth.BearerAuth != nil { bearerAuth := &BearerAuthConfig{ - Enabled: req.Auth.BearerAuth.Enabled, + Enabled: reqAuth.BearerAuth.Enabled, } - if req.Auth.BearerAuth.DistributionGroups != nil { - bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups + if reqAuth.BearerAuth.DistributionGroups != nil { + bearerAuth.DistributionGroups = *reqAuth.BearerAuth.DistributionGroups } - s.Auth.BearerAuth = bearerAuth + auth.BearerAuth = bearerAuth } - - return nil + return auth } func (s *Service) Validate() error { @@ -478,14 +549,69 @@ func (s *Service) Validate() error { return errors.New("service name exceeds maximum length of 255 characters") } - if s.Domain == "" { - return errors.New("service domain is required") - } - if len(s.Targets) == 0 { return errors.New("at least one target is required") } + if s.Mode == "" { + s.Mode = ModeHTTP + } + + switch s.Mode { + case ModeHTTP: + return s.validateHTTPMode() + case ModeTCP, ModeUDP: + return s.validateTCPUDPMode() + case ModeTLS: + return s.validateTLSMode() + default: + return fmt.Errorf("unsupported mode %q", s.Mode) + } +} + +func (s *Service) validateHTTPMode() error { + if s.Domain == "" { + return errors.New("service domain is required") + } + if s.ListenPort != 0 { + return errors.New("listen_port is not supported for HTTP services") + } + return s.validateHTTPTargets() +} + +func (s *Service) validateTCPUDPMode() error { + if s.Domain == "" { + return errors.New("domain is required for TCP/UDP services (used for cluster derivation)") + } + if s.isAuthEnabled() { + return errors.New("auth is not supported for TCP/UDP services") + } + if len(s.Targets) != 1 { + return errors.New("TCP/UDP services must have exactly one target") + } + if s.Mode == ModeUDP && s.Targets[0].ProxyProtocol { + return errors.New("proxy_protocol is not supported for UDP services") + } + return s.validateL4Target(s.Targets[0]) +} + +func (s *Service) validateTLSMode() error { + if s.Domain == "" { + return errors.New("domain is required for TLS services (used for SNI matching)") + } + if s.isAuthEnabled() { + return errors.New("auth is not supported for TLS services") + } + if s.ListenPort == 0 { + return errors.New("listen_port is required for TLS services") + } + if len(s.Targets) != 1 { + return errors.New("TLS services must have exactly one target") + } + return s.validateL4Target(s.Targets[0]) +} + +func (s *Service) validateHTTPTargets() error { for i, target := range s.Targets { switch target.TargetType { case TargetTypePeer, TargetTypeHost, TargetTypeDomain: @@ -500,6 +626,9 @@ func (s *Service) Validate() error { if target.TargetId == "" { return fmt.Errorf("target %d has empty target_id", i) } + if target.ProxyProtocol { + return fmt.Errorf("target %d: proxy_protocol is not supported for HTTP services", i) + } if err := validateTargetOptions(i, &target.Options); err != nil { return err } @@ -508,11 +637,62 @@ func (s *Service) Validate() error { return nil } +func (s *Service) validateL4Target(target *Target) error { + if target.Port == 0 { + return errors.New("target port is required for L4 services") + } + if target.TargetId == "" { + return errors.New("target_id is required for L4 services") + } + switch target.TargetType { + case TargetTypePeer, TargetTypeHost: + // OK + case TargetTypeSubnet: + if target.Host == "" { + return errors.New("target host is required for subnet targets") + } + default: + return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType) + } + if target.Path != nil && *target.Path != "" && *target.Path != "/" { + return errors.New("path is not supported for L4 services") + } + return nil +} + +// Service mode constants. const ( - maxRequestTimeout = 5 * time.Minute - maxCustomHeaders = 16 - maxHeaderKeyLen = 128 - maxHeaderValueLen = 4096 + ModeHTTP = "http" + ModeTCP = "tcp" + ModeUDP = "udp" + ModeTLS = "tls" +) + +// Target protocol constants (URL scheme for backend connections). +const ( + TargetProtoHTTP = "http" + TargetProtoHTTPS = "https" + TargetProtoTCP = "tcp" + TargetProtoUDP = "udp" +) + +// IsL4Protocol returns true if the mode requires port-based routing (TCP, UDP, or TLS). +func IsL4Protocol(mode string) bool { + return mode == ModeTCP || mode == ModeUDP || mode == ModeTLS +} + +// IsPortBasedProtocol returns true if the mode relies on dedicated port allocation. +// TLS is excluded because it uses SNI routing and can share ports with other TLS services. +func IsPortBasedProtocol(mode string) bool { + return mode == ModeTCP || mode == ModeUDP +} + +const ( + maxRequestTimeout = 5 * time.Minute + maxSessionIdleTimeout = 10 * time.Minute + maxCustomHeaders = 16 + maxHeaderKeyLen = 128 + maxHeaderValueLen = 4096 ) // httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition. @@ -560,6 +740,15 @@ func validateTargetOptions(idx int, opts *TargetOptions) error { } } + if opts.SessionIdleTimeout != 0 { + if opts.SessionIdleTimeout <= 0 { + return fmt.Errorf("target %d: session_idle_timeout must be positive", idx) + } + if opts.SessionIdleTimeout > maxSessionIdleTimeout { + return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout) + } + } + if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil { return err } @@ -608,17 +797,49 @@ func containsCRLF(s string) bool { } func (s *Service) EventMeta() map[string]any { - return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()} + meta := map[string]any{ + "name": s.Name, + "domain": s.Domain, + "proxy_cluster": s.ProxyCluster, + "source": s.Source, + "auth": s.isAuthEnabled(), + "mode": s.Mode, + } + + if s.ListenPort != 0 { + meta["listen_port"] = s.ListenPort + } + + if len(s.Targets) > 0 { + t := s.Targets[0] + if t.ProxyProtocol { + meta["proxy_protocol"] = true + } + if t.Options.RequestTimeout != 0 { + meta["request_timeout"] = t.Options.RequestTimeout.String() + } + if t.Options.SessionIdleTimeout != 0 { + meta["session_idle_timeout"] = t.Options.SessionIdleTimeout.String() + } + } + + return meta } func (s *Service) isAuthEnabled() bool { - return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil + return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) || + (s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) || + (s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) } func (s *Service) Copy() *Service { targets := make([]*Target, len(s.Targets)) for i, target := range s.Targets { targetCopy := *target + if target.Path != nil { + p := *target.Path + targetCopy.Path = &p + } if len(target.Options.CustomHeaders) > 0 { targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders)) for k, v := range target.Options.CustomHeaders { @@ -628,6 +849,24 @@ func (s *Service) Copy() *Service { targets[i] = &targetCopy } + authCopy := s.Auth + if s.Auth.PasswordAuth != nil { + pa := *s.Auth.PasswordAuth + authCopy.PasswordAuth = &pa + } + if s.Auth.PinAuth != nil { + pa := *s.Auth.PinAuth + authCopy.PinAuth = &pa + } + if s.Auth.BearerAuth != nil { + ba := *s.Auth.BearerAuth + if len(s.Auth.BearerAuth.DistributionGroups) > 0 { + ba.DistributionGroups = make([]string, len(s.Auth.BearerAuth.DistributionGroups)) + copy(ba.DistributionGroups, s.Auth.BearerAuth.DistributionGroups) + } + authCopy.BearerAuth = &ba + } + return &Service{ ID: s.ID, AccountID: s.AccountID, @@ -638,12 +877,15 @@ func (s *Service) Copy() *Service { Enabled: s.Enabled, PassHostHeader: s.PassHostHeader, RewriteRedirects: s.RewriteRedirects, - Auth: s.Auth, + Auth: authCopy, Meta: s.Meta, SessionPrivateKey: s.SessionPrivateKey, SessionPublicKey: s.SessionPublicKey, Source: s.Source, SourcePeer: s.SourcePeer, + Mode: s.Mode, + ListenPort: s.ListenPort, + PortAutoAssigned: s.PortAutoAssigned, } } @@ -688,12 +930,16 @@ var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`) // ExposeServiceRequest contains the parameters for creating a peer-initiated expose service. type ExposeServiceRequest struct { NamePrefix string - Port int - Protocol string - Domain string - Pin string - Password string - UserGroups []string + Port uint16 + Mode string + // TargetProtocol is the protocol used to connect to the peer backend. + // For HTTP mode: "http" (default) or "https". For L4 modes: "tcp" or "udp". + TargetProtocol string + Domain string + Pin string + Password string + UserGroups []string + ListenPort uint16 } // Validate checks all fields of the expose request. @@ -702,12 +948,20 @@ func (r *ExposeServiceRequest) Validate() error { return errors.New("request cannot be nil") } - if r.Port < 1 || r.Port > 65535 { + if r.Port == 0 { return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port) } - if r.Protocol != "http" && r.Protocol != "https" { - return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol) + switch r.Mode { + case ModeHTTP, ModeTCP, ModeUDP, ModeTLS: + default: + return fmt.Errorf("unsupported mode %q", r.Mode) + } + + if IsL4Protocol(r.Mode) { + if r.Pin != "" || r.Password != "" || len(r.UserGroups) > 0 { + return fmt.Errorf("authentication is not supported for %s mode", r.Mode) + } } if r.Pin != "" && !pinRegexp.MatchString(r.Pin) { @@ -729,55 +983,79 @@ func (r *ExposeServiceRequest) Validate() error { // ToService builds a Service from the expose request. func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service { - service := &Service{ + svc := &Service{ AccountID: accountID, Name: serviceName, + Mode: r.Mode, Enabled: true, - Targets: []*Target{ - { - AccountID: accountID, - Port: r.Port, - Protocol: r.Protocol, - TargetId: peerID, - TargetType: TargetTypePeer, - Enabled: true, - }, + } + + // If domain is empty, CreateServiceFromPeer generates a unique subdomain. + // When explicitly provided, the service name is prepended as a subdomain. + if r.Domain != "" { + svc.Domain = serviceName + "." + r.Domain + } + + if IsL4Protocol(r.Mode) { + svc.ListenPort = r.Port + if r.ListenPort > 0 { + svc.ListenPort = r.ListenPort + } + } + + var targetProto string + switch { + case !IsL4Protocol(r.Mode): + targetProto = TargetProtoHTTP + if r.TargetProtocol != "" { + targetProto = r.TargetProtocol + } + case r.Mode == ModeUDP: + targetProto = TargetProtoUDP + default: + targetProto = TargetProtoTCP + } + svc.Targets = []*Target{ + { + AccountID: accountID, + Port: r.Port, + Protocol: targetProto, + TargetId: peerID, + TargetType: TargetTypePeer, + Enabled: true, }, } - if r.Domain != "" { - service.Domain = serviceName + "." + r.Domain - } - if r.Pin != "" { - service.Auth.PinAuth = &PINAuthConfig{ + svc.Auth.PinAuth = &PINAuthConfig{ Enabled: true, Pin: r.Pin, } } if r.Password != "" { - service.Auth.PasswordAuth = &PasswordAuthConfig{ + svc.Auth.PasswordAuth = &PasswordAuthConfig{ Enabled: true, Password: r.Password, } } if len(r.UserGroups) > 0 { - service.Auth.BearerAuth = &BearerAuthConfig{ + svc.Auth.BearerAuth = &BearerAuthConfig{ Enabled: true, DistributionGroups: r.UserGroups, } } - return service + return svc } // ExposeServiceResponse contains the result of a successful peer expose creation. type ExposeServiceResponse struct { - ServiceName string - ServiceURL string - Domain string + ServiceName string + ServiceURL string + Domain string + PortAutoAssigned bool } // GenerateExposeName generates a random service name for peer-exposed services. diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index 79c98fc14..a8a8ae5d6 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -44,7 +44,7 @@ func TestValidate_EmptyDomain(t *testing.T) { func TestValidate_NoTargets(t *testing.T) { rp := validProxy() rp.Targets = nil - assert.ErrorContains(t, rp.Validate(), "at least one target") + assert.ErrorContains(t, rp.Validate(), "at least one target is required") } func TestValidate_EmptyTargetId(t *testing.T) { @@ -273,7 +273,7 @@ func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) { func TestIsDefaultPort(t *testing.T) { tests := []struct { scheme string - port int + port uint16 want bool }{ {"http", 80, true}, @@ -299,7 +299,7 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) { name string protocol string host string - port int + port uint16 wantTarget string }{ { @@ -645,8 +645,8 @@ func TestGenerateExposeName(t *testing.T) { func TestExposeServiceRequest_ToService(t *testing.T) { t.Run("basic HTTP service", func(t *testing.T) { req := &ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } service := req.ToService("account-1", "peer-1", "mysvc") @@ -658,7 +658,7 @@ func TestExposeServiceRequest_ToService(t *testing.T) { require.Len(t, service.Targets, 1) target := service.Targets[0] - assert.Equal(t, 8080, target.Port) + assert.Equal(t, uint16(8080), target.Port) assert.Equal(t, "http", target.Protocol) assert.Equal(t, "peer-1", target.TargetId) assert.Equal(t, TargetTypePeer, target.TargetType) @@ -730,3 +730,182 @@ func TestExposeServiceRequest_ToService(t *testing.T) { require.NotNil(t, service.Auth.BearerAuth) }) } + +func TestValidate_TLSOnly(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_TLSMissingListenPort(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 0, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "listen_port is required") +} + +func TestValidate_TLSMissingDomain(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "domain is required") +} + +func TestValidate_TCPValid(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_TCPMissingListenPort(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + require.NoError(t, rp.Validate(), "TCP with listen_port=0 is valid (auto-assigned by manager)") +} + +func TestValidate_L4MultipleTargets(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + {TargetId: "peer-2", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "exactly one target") +} + +func TestValidate_L4TargetMissingPort(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 0, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "port is required") +} + +func TestValidate_TLSInvalidTargetType(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: "invalid", Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.Error(t, rp.Validate()) +} + +func TestValidate_TLSSubnetValid(t *testing.T) { + rp := &Service{ + Name: "tls-subnet", + Mode: "tls", + Domain: "example.com", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "subnet-1", TargetType: TargetTypeSubnet, Protocol: "tcp", Port: 443, Host: "10.0.0.5", Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_HTTPProxyProtocolRejected(t *testing.T) { + rp := validProxy() + rp.Targets[0].ProxyProtocol = true + assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for HTTP") +} + +func TestValidate_UDPProxyProtocolRejected(t *testing.T) { + rp := &Service{ + Name: "udp-svc", + Mode: "udp", + Domain: "cluster.test", + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "udp", Port: 5432, Enabled: true, ProxyProtocol: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for UDP") +} + +func TestValidate_TCPProxyProtocolAllowed(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true, ProxyProtocol: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestExposeServiceRequest_Validate_L4RejectsAuth(t *testing.T) { + tests := []struct { + name string + req ExposeServiceRequest + }{ + { + name: "tcp with pin", + req: ExposeServiceRequest{Port: 8080, Mode: "tcp", Pin: "123456"}, + }, + { + name: "udp with password", + req: ExposeServiceRequest{Port: 8080, Mode: "udp", Password: "secret"}, + }, + { + name: "tls with user groups", + req: ExposeServiceRequest{Port: 443, Mode: "tls", UserGroups: []string{"admins"}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.req.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "authentication is not supported") + }) + } +} + +func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) { + req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"} + require.NoError(t, req.Validate()) +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index eb13a15e3..88d37ca80 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 29a8953ac..a32cf6046 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" @@ -211,6 +212,9 @@ func (s *BaseServer) ProxyManager() proxy.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return Create(s, func() *manager.Manager { m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager()) + s.AfterInit(func(s *BaseServer) { + m.SetClusterCapabilities(s.ServiceProxyController()) + }) return &m }) } diff --git a/management/internals/shared/grpc/expose_service.go b/management/internals/shared/grpc/expose_service.go index c444471b0..1b87f7ede 100644 --- a/management/internals/shared/grpc/expose_service.go +++ b/management/internals/shared/grpc/expose_service.go @@ -2,6 +2,7 @@ package grpc import ( "context" + "fmt" pb "github.com/golang/protobuf/proto" // nolint log "github.com/sirupsen/logrus" @@ -39,23 +40,38 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } + if exposeReq.Port > 65535 { + return nil, status.Errorf(codes.InvalidArgument, "port out of range: %d", exposeReq.Port) + } + if exposeReq.ListenPort > 65535 { + return nil, status.Errorf(codes.InvalidArgument, "listen_port out of range: %d", exposeReq.ListenPort) + } + + mode, err := exposeProtocolToString(exposeReq.Protocol) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{ - NamePrefix: exposeReq.NamePrefix, - Port: int(exposeReq.Port), - Protocol: exposeProtocolToString(exposeReq.Protocol), - Domain: exposeReq.Domain, - Pin: exposeReq.Pin, - Password: exposeReq.Password, - UserGroups: exposeReq.UserGroups, + NamePrefix: exposeReq.NamePrefix, + Port: uint16(exposeReq.Port), //nolint:gosec // validated above + Mode: mode, + TargetProtocol: exposeTargetProtocol(exposeReq.Protocol), + Domain: exposeReq.Domain, + Pin: exposeReq.Pin, + Password: exposeReq.Password, + UserGroups: exposeReq.UserGroups, + ListenPort: uint16(exposeReq.ListenPort), //nolint:gosec // validated above }) if err != nil { return nil, mapExposeError(ctx, err) } return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{ - ServiceName: created.ServiceName, - ServiceUrl: created.ServiceURL, - Domain: created.Domain, + ServiceName: created.ServiceName, + ServiceUrl: created.ServiceURL, + Domain: created.Domain, + PortAutoAssigned: created.PortAutoAssigned, }) } @@ -77,7 +93,12 @@ func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) ( return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } - if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil { + serviceID, err := s.resolveServiceID(ctx, renewReq.Domain) + if err != nil { + return nil, mapExposeError(ctx, err) + } + + if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil { return nil, mapExposeError(ctx, err) } @@ -102,7 +123,12 @@ func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (* return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } - if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil { + serviceID, err := s.resolveServiceID(ctx, stopReq.Domain) + if err != nil { + return nil, mapExposeError(ctx, err) + } + + if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil { return nil, mapExposeError(ctx, err) } @@ -180,13 +206,46 @@ func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) { s.reverseProxyManager = mgr } -func exposeProtocolToString(p proto.ExposeProtocol) string { +// resolveServiceID looks up the service by its globally unique domain. +func (s *Server) resolveServiceID(ctx context.Context, domain string) (string, error) { + if domain == "" { + return "", status.Errorf(codes.InvalidArgument, "domain is required") + } + + svc, err := s.accountManager.GetStore().GetServiceByDomain(ctx, domain) + if err != nil { + return "", err + } + return svc.ID, nil +} + +func exposeProtocolToString(p proto.ExposeProtocol) (string, error) { switch p { - case proto.ExposeProtocol_EXPOSE_HTTP: - return "http" - case proto.ExposeProtocol_EXPOSE_HTTPS: - return "https" + case proto.ExposeProtocol_EXPOSE_HTTP, proto.ExposeProtocol_EXPOSE_HTTPS: + return "http", nil + case proto.ExposeProtocol_EXPOSE_TCP: + return "tcp", nil + case proto.ExposeProtocol_EXPOSE_UDP: + return "udp", nil + case proto.ExposeProtocol_EXPOSE_TLS: + return "tls", nil default: - return "http" + return "", fmt.Errorf("unsupported expose protocol: %v", p) + } +} + +// exposeTargetProtocol returns the target protocol for the given expose protocol. +// For HTTP mode, this is http or https (the scheme used to connect to the backend). +// For L4 modes, this is tcp or udp (the transport used to connect to the backend). +func exposeTargetProtocol(p proto.ExposeProtocol) string { + switch p { + case proto.ExposeProtocol_EXPOSE_HTTPS: + return rpservice.TargetProtoHTTPS + case proto.ExposeProtocol_EXPOSE_TCP, proto.ExposeProtocol_EXPOSE_TLS: + return rpservice.TargetProtoTCP + case proto.ExposeProtocol_EXPOSE_UDP: + return rpservice.TargetProtoUDP + default: + return rpservice.TargetProtoHTTP } } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index e2d0f1abe..31a0ba0db 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -32,6 +32,7 @@ import ( proxyauth "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/shared/hash/argon2id" "github.com/netbirdio/netbird/shared/management/proto" + nbstatus "github.com/netbirdio/netbird/shared/management/status" ) type ProxyOIDCConfig struct { @@ -45,12 +46,6 @@ type ProxyOIDCConfig struct { KeysLocation string } -// ClusterInfo contains information about a proxy cluster. -type ClusterInfo struct { - Address string - ConnectedProxies int -} - // ProxyServiceServer implements the ProxyService gRPC server type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -61,9 +56,9 @@ type ProxyServiceServer struct { // Manager for access logs accessLogManager accesslogs.Manager + mu sync.RWMutex // Manager for reverse proxy operations serviceManager rpservice.Manager - // ProxyController for service updates and cluster management proxyController proxy.Controller @@ -84,23 +79,26 @@ type ProxyServiceServer struct { // Store for PKCE verifiers pkceVerifierStore *PKCEVerifierStore + + cancel context.CancelFunc } const pkceVerifierTTL = 10 * time.Minute // proxyConnection represents a connected proxy type proxyConnection struct { - proxyID string - address string - stream proto.ProxyService_GetMappingUpdateServer - sendChan chan *proto.GetMappingUpdateResponse - ctx context.Context - cancel context.CancelFunc + proxyID string + address string + capabilities *proto.ProxyCapabilities + stream proto.ProxyService_GetMappingUpdateServer + sendChan chan *proto.GetMappingUpdateResponse + ctx context.Context + cancel context.CancelFunc } // NewProxyServiceServer creates a new proxy service server. func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) s := &ProxyServiceServer{ accessLogManager: accessLogMgr, oidcConfig: oidcConfig, @@ -109,6 +107,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + cancel: cancel, } go s.cleanupStaleProxies(ctx) return s @@ -130,11 +129,22 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { } } +// Close stops background goroutines. +func (s *ProxyServiceServer) Close() { + s.cancel() +} + +// SetServiceManager sets the service manager. Must be called before serving. func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) { + s.mu.Lock() + defer s.mu.Unlock() s.serviceManager = manager } +// SetProxyController sets the proxy controller. Must be called before serving. func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) { + s.mu.Lock() + defer s.mu.Unlock() s.proxyController = proxyController } @@ -157,12 +167,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest connCtx, cancel := context.WithCancel(ctx) conn := &proxyConnection{ - proxyID: proxyID, - address: proxyAddress, - stream: stream, - sendChan: make(chan *proto.GetMappingUpdateResponse, 100), - ctx: connCtx, - cancel: cancel, + proxyID: proxyID, + address: proxyAddress, + capabilities: req.GetCapabilities(), + stream: stream, + sendChan: make(chan *proto.GetMappingUpdateResponse, 100), + ctx: connCtx, + cancel: cancel, } s.connectedProxies.Store(proxyID, conn) @@ -231,29 +242,18 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) { } // sendSnapshot sends the initial snapshot of services to the connecting proxy. -// Only services matching the proxy's cluster address are sent. +// Only entries matching the proxy's cluster address are sent. func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { - services, err := s.serviceManager.GetGlobalServices(ctx) - if err != nil { - return fmt.Errorf("get services from store: %w", err) - } - if !isProxyAddressValid(conn.address) { return fmt.Errorf("proxy address is invalid") } - var filtered []*rpservice.Service - for _, service := range services { - if !service.Enabled { - continue - } - if service.ProxyCluster == "" || service.ProxyCluster != conn.address { - continue - } - filtered = append(filtered, service) + mappings, err := s.snapshotServiceMappings(ctx, conn) + if err != nil { + return err } - if len(filtered) == 0 { + if len(mappings) == 0 { if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ InitialSyncComplete: true, }); err != nil { @@ -262,9 +262,30 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return nil } - for i, service := range filtered { - // Generate one-time authentication token for each service in the snapshot - // Tokens are not persistent on the proxy, so we need to generate new ones on reconnection + for i, m := range mappings { + if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{m}, + InitialSyncComplete: i == len(mappings)-1, + }); err != nil { + return fmt.Errorf("send proxy mapping: %w", err) + } + } + + return nil +} + +func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) { + services, err := s.serviceManager.GetGlobalServices(ctx) + if err != nil { + return nil, fmt.Errorf("get services from store: %w", err) + } + + var mappings []*proto.ProxyMapping + for _, service := range services { + if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address { + continue + } + token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute) if err != nil { log.WithFields(log.Fields{ @@ -274,25 +295,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec continue } - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ - Mapping: []*proto.ProxyMapping{ - service.ToProtoMapping( - rpservice.Create, // Initial snapshot, all records are "new" for the proxy. - token, - s.GetOIDCValidationConfig(), - ), - }, - InitialSyncComplete: i == len(filtered)-1, - }); err != nil { - log.WithFields(log.Fields{ - "domain": service.Domain, - "account": service.AccountID, - }).WithError(err).Error("failed to send proxy mapping") - return fmt.Errorf("send proxy mapping: %w", err) - } + m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) + mappings = append(mappings, m) } - - return nil + return mappings, nil } // isProxyAddressValid validates a proxy address @@ -305,8 +311,8 @@ func isProxyAddressValid(addr string) bool { func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) { for { select { - case msg := <-conn.sendChan: - if err := conn.stream.Send(msg); err != nil { + case resp := <-conn.sendChan: + if err := conn.stream.Send(resp); err != nil { errChan <- err return } @@ -361,12 +367,12 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes log.Debugf("Broadcasting service update to all connected proxy servers") s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) - msg := s.perProxyMessage(update, conn.proxyID) - if msg == nil { + resp := s.perProxyMessage(update, conn.proxyID) + if resp == nil { return true } select { - case conn.sendChan <- msg: + case conn.sendChan <- resp: log.Debugf("Sent service update to proxy server %s", conn.proxyID) default: log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) @@ -495,9 +501,40 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { Auth: m.Auth, PassHostHeader: m.PassHostHeader, RewriteRedirects: m.RewriteRedirects, + Mode: m.Mode, + ListenPort: m.ListenPort, } } +// ClusterSupportsCustomPorts returns whether any connected proxy in the given +// cluster reports custom port support. Returns nil if no proxy has reported +// capabilities (old proxies that predate the field). +func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool { + if s.proxyController == nil { + return nil + } + + var hasCapabilities bool + for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) { + connVal, ok := s.connectedProxies.Load(pid) + if !ok { + continue + } + conn := connVal.(*proxyConnection) + if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil { + continue + } + if *conn.capabilities.SupportsCustomPorts { + return ptr(true) + } + hasCapabilities = true + } + if hasCapabilities { + return ptr(false) + } + return nil +} + func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { @@ -585,7 +622,7 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic return token, nil } -// SendStatusUpdate handles status updates from proxy clients +// SendStatusUpdate handles status updates from proxy clients. func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { accountID := req.GetAccountId() serviceID := req.GetServiceId() @@ -604,6 +641,17 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required") } + internalStatus := protoStatusToInternal(protoStatus) + + if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { + sErr, isNbErr := nbstatus.FromError(err) + if isNbErr && sErr.Type() == nbstatus.NotFound { + return nil, status.Errorf(codes.NotFound, "service %s not found", serviceID) + } + log.WithContext(ctx).WithError(err).Error("failed to update service status") + return nil, status.Errorf(codes.Internal, "update service status: %v", err) + } + if certificateIssued { if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil { log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp") @@ -615,13 +663,6 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se }).Info("Certificate issued timestamp updated") } - internalStatus := protoStatusToInternal(protoStatus) - - if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { - log.WithContext(ctx).WithError(err).Error("failed to update service status") - return nil, status.Errorf(codes.Internal, "update service status: %v", err) - } - log.WithFields(log.Fields{ "service_id": serviceID, "account_id": accountID, @@ -631,7 +672,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se return &proto.SendStatusUpdateResponse{}, nil } -// protoStatusToInternal maps proto status to internal status +// protoStatusToInternal maps proto status to internal service status. func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status { switch protoStatus { case proto.ProxyStatus_PROXY_STATUS_PENDING: @@ -1061,3 +1102,5 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user * return fmt.Errorf("user not in allowed groups") } + +func ptr[T any](v T) *T { return &v } diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index b7abb28b6..1a4ea3330 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -53,6 +53,10 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus return nil } +func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { + return ptr(true) +} + func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string { c.mu.Lock() defer c.mu.Unlock() @@ -70,11 +74,17 @@ func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string // registerFakeProxy adds a fake proxy connection to the server's internal maps // and returns the channel where messages will be received. func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse { + return registerFakeProxyWithCaps(s, proxyID, clusterAddr, nil) +} + +// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities. +func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse { ch := make(chan *proto.GetMappingUpdateResponse, 10) conn := &proxyConnection{ - proxyID: proxyID, - address: clusterAddr, - sendChan: ch, + proxyID: proxyID, + address: clusterAddr, + capabilities: caps, + sendChan: ch, } s.connectedProxies.Store(proxyID, conn) @@ -83,15 +93,29 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan return ch } -func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse { +// drainMapping drains a single ProxyMapping from the channel. +func drainMapping(ch chan *proto.GetMappingUpdateResponse) *proto.ProxyMapping { select { - case msg := <-ch: - return msg + case resp := <-ch: + if len(resp.Mapping) > 0 { + return resp.Mapping[0] + } + return nil case <-time.After(time.Second): return nil } } +// drainEmpty checks if a channel has no message within timeout. +func drainEmpty(ch chan *proto.GetMappingUpdateResponse) bool { + select { + case <-ch: + return false + case <-time.After(100 * time.Millisecond): + return true + } +} + func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { ctx := context.Background() tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) @@ -129,10 +153,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { tokens := make([]string, numProxies) for i, ch := range channels { - resp := drainChannel(ch) - require.NotNil(t, resp, "proxy %d should receive a message", i) - require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i) - msg := resp.Mapping[0] + msg := drainMapping(ch) + require.NotNil(t, msg, "proxy %d should receive a message", i) assert.Equal(t, mapping.Domain, msg.Domain) assert.Equal(t, mapping.Id, msg.Id) assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i) @@ -181,16 +203,14 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { s.SendServiceUpdateToCluster(context.Background(), mapping, cluster) - resp1 := drainChannel(ch1) - resp2 := drainChannel(ch2) - require.NotNil(t, resp1) - require.NotNil(t, resp2) - require.Len(t, resp1.Mapping, 1) - require.Len(t, resp2.Mapping, 1) + msg1 := drainMapping(ch1) + msg2 := drainMapping(ch2) + require.NotNil(t, msg1) + require.NotNil(t, msg2) // Delete operations should not generate tokens - assert.Empty(t, resp1.Mapping[0].AuthToken) - assert.Empty(t, resp2.Mapping[0].AuthToken) + assert.Empty(t, msg1.AuthToken) + assert.Empty(t, msg2.AuthToken) } func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { @@ -224,15 +244,10 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { s.SendServiceUpdate(update) - resp1 := drainChannel(ch1) - resp2 := drainChannel(ch2) - require.NotNil(t, resp1) - require.NotNil(t, resp2) - require.Len(t, resp1.Mapping, 1) - require.Len(t, resp2.Mapping, 1) - - msg1 := resp1.Mapping[0] - msg2 := resp2.Mapping[0] + msg1 := drainMapping(ch1) + msg2 := drainMapping(ch2) + require.NotNil(t, msg1) + require.NotNil(t, msg2) assert.NotEmpty(t, msg1.AuthToken) assert.NotEmpty(t, msg2.AuthToken) @@ -324,3 +339,314 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "invalid state signature") } + +func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + + const cluster = "proxy.example.com" + + // Proxy A supports custom ports. + chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + // Proxy B does NOT support custom ports (shared cloud proxy). + chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + ctx := context.Background() + + // TLS passthrough works on all proxies regardless of custom port support. + tlsMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tls", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tls", + ListenPort: 8443, + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster) + + msgA := drainMapping(chA) + msgB := drainMapping(chB) + assert.NotNil(t, msgA, "proxy-a should receive TLS mapping") + assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)") + + // Send an HTTP mapping: both should receive it. + httpMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-http", + AccountId: "account-1", + Domain: "app.example.com", + Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:80"}}, + } + + s.SendServiceUpdateToCluster(ctx, httpMapping, cluster) + + msgA = drainMapping(chA) + msgB = drainMapping(chB) + assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping") + assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping") +} + +func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + + const cluster = "proxy.example.com" + + chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + tlsMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tls", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tls", + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster) + + msg := drainMapping(chShared) + assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support") +} + +// TestServiceModifyNotifications exercises every possible modification +// scenario for an existing service, verifying the correct update types +// reach the correct clusters. +func TestServiceModifyNotifications(t *testing.T) { + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + newServer := func() (*ProxyServiceServer, map[string]chan *proto.GetMappingUpdateResponse) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + chs := map[string]chan *proto.GetMappingUpdateResponse{ + "cluster-a": registerFakeProxyWithCaps(s, "proxy-a", "cluster-a", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}), + "cluster-b": registerFakeProxyWithCaps(s, "proxy-b", "cluster-b", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}), + } + return s, chs + } + + httpMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping { + return &proto.ProxyMapping{ + Type: updateType, + Id: "svc-1", + AccountId: "acct-1", + Domain: "app.example.com", + Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:8080"}}, + } + } + + tlsOnlyMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping { + return &proto.ProxyMapping{ + Type: updateType, + Id: "svc-1", + AccountId: "acct-1", + Domain: "app.example.com", + Mode: "tls", + ListenPort: 8443, + Path: []*proto.PathMapping{{Target: "10.0.0.1:443"}}, + } + } + + ctx := context.Background() + + t.Run("targets changed sends MODIFIED to same cluster", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg, "cluster-a should receive update") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.NotEmpty(t, msg.AuthToken, "MODIFIED should include token") + assert.True(t, drainEmpty(chs["cluster-b"]), "cluster-b should not receive update") + }) + + t.Run("auth config changed sends MODIFIED", func(t *testing.T) { + s, chs := newServer() + mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.Auth = &proto.Authentication{Password: true, Pin: true} + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.True(t, msg.Auth.Password) + assert.True(t, msg.Auth.Pin) + }) + + t.Run("HTTP to TLS transition sends MODIFIED with TLS config", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.Equal(t, "tls", msg.Mode, "mode should be tls") + assert.Equal(t, int32(8443), msg.ListenPort) + assert.Len(t, msg.Path, 1, "should have one path entry with target address") + assert.Equal(t, "10.0.0.1:443", msg.Path[0].Target) + }) + + t.Run("TLS to HTTP transition sends MODIFIED without TLS", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.Empty(t, msg.Mode, "mode should be empty for HTTP") + assert.True(t, len(msg.Path) > 0) + }) + + t.Run("TLS port changed sends MODIFIED with new port", func(t *testing.T) { + s, chs := newServer() + mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.ListenPort = 9443 + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, int32(9443), msg.ListenPort) + }) + + t.Run("disable sends REMOVED to cluster", func(t *testing.T) { + s, chs := newServer() + // Manager sends Delete when service is disabled + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msg.Type) + assert.Empty(t, msg.AuthToken, "DELETE should not have token") + }) + + t.Run("enable sends CREATED to cluster", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msg.Type) + assert.NotEmpty(t, msg.AuthToken) + }) + + t.Run("domain change with cluster change sends DELETE to old CREATE to new", func(t *testing.T) { + s, chs := newServer() + // This is the pattern the manager produces: + // 1. DELETE on old cluster + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + // 2. CREATE on new cluster + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-b") + + msgA := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgA, "old cluster should receive DELETE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgA.Type) + + msgB := drainMapping(chs["cluster-b"]) + require.NotNil(t, msgB, "new cluster should receive CREATE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgB.Type) + assert.NotEmpty(t, msgB.AuthToken) + }) + + t.Run("domain change same cluster sends DELETE then CREATE", func(t *testing.T) { + s, chs := newServer() + // Domain changes within same cluster: manager sends DELETE (old domain) + CREATE (new domain). + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a") + + msgDel := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgDel, "same cluster should receive DELETE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgDel.Type) + + msgCreate := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgCreate, "same cluster should receive CREATE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgCreate.Type) + assert.NotEmpty(t, msgCreate.AuthToken) + }) + + t.Run("TLS passthrough sent to all proxies", func(t *testing.T) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + const cluster = "proxy.example.com" + chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + // TLS passthrough works on all proxies regardless of custom port support + s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster) + + msgModern := drainMapping(chModern) + require.NotNil(t, msgModern, "modern proxy receives TLS update") + assert.Equal(t, "tls", msgModern.Mode) + + msgLegacy := drainMapping(chLegacy) + assert.NotNil(t, msgLegacy, "legacy proxy should also receive TLS passthrough") + }) + + t.Run("TLS on default port NOT filtered for legacy proxy", func(t *testing.T) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + const cluster = "proxy.example.com" + chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.ListenPort = 0 // default port + s.SendServiceUpdateToCluster(ctx, mapping, cluster) + + msgLegacy := drainMapping(chLegacy) + assert.NotNil(t, msgLegacy, "legacy proxy should receive TLS on default port") + }) + + t.Run("passthrough and rewrite flags propagated", func(t *testing.T) { + s, chs := newServer() + mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.PassHostHeader = true + mapping.RewriteRedirects = true + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.True(t, msg.PassHostHeader) + assert.True(t, msg.RewriteRedirects) + }) + + t.Run("multiple paths propagated in MODIFIED", func(t *testing.T) { + s, chs := newServer() + mapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, + Id: "svc-multi", + AccountId: "acct-1", + Domain: "multi.example.com", + Path: []*proto.PathMapping{ + {Path: "/", Target: "http://10.0.0.1:8080"}, + {Path: "/api", Target: "http://10.0.0.2:9090"}, + {Path: "/ws", Target: "http://10.0.0.3:3000"}, + }, + } + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + require.Len(t, msg.Path, 3, "all paths should be present") + assert.Equal(t, "/", msg.Path[0].Path) + assert.Equal(t, "/api", msg.Path[1].Path) + assert.Equal(t, "/ws", msg.Path[2].Path) + }) +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ddeda6d7f..ad36b9d46 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -174,9 +174,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks instance.AddEndpoints(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router) if serviceManager != nil && reverseProxyDomainManager != nil { - reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) + reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) } - // Register OAuth callback handler for proxy authentication if proxyGRPCServer != nil { oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies) diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 462013963..6bd269a2c 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -12,6 +12,7 @@ import ( "go.opentelemetry.io/otel/metric/noop" "github.com/netbirdio/management-integrations/integrations" + accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" @@ -113,6 +114,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) } + domainManager.SetClusterCapabilities(serviceProxyController) serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager) proxyServiceServer.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager) diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index bfefce388..8732cf89f 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -219,7 +219,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { servicesStatusActive int servicesStatusPending int servicesStatusError int - servicesTargetType map[string]int + servicesTargetType map[rpservice.TargetType]int servicesAuthPassword int servicesAuthPin int servicesAuthOIDC int @@ -232,7 +232,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { rulesDirection = make(map[string]int) activeUsersLastDay = make(map[string]struct{}) embeddedIdpTypes = make(map[string]int) - servicesTargetType = make(map[string]int) + servicesTargetType = make(map[rpservice.TargetType]int) uptime = time.Since(w.startupTime).Seconds() connections := w.connManager.GetAllConnectedPeers() version = nbversion.NetbirdVersion() @@ -434,7 +434,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["custom_domains_validated"] = customDomainsValidated for targetType, count := range servicesTargetType { - metricsProperties["services_target_type_"+targetType] = count + metricsProperties["services_target_type_"+string(targetType)] = count } for idpType, count := range embeddedIdpTypes { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 5997c10e2..b3fbfe141 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -30,6 +30,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" @@ -4996,6 +4997,7 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpse return service, nil } + func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) { tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { @@ -5041,16 +5043,16 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS } // RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service. -func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error { +func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error { result := s.db.Model(&rpservice.Service{}). - Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral). + Where("id = ? AND account_id = ? AND source_peer = ? AND source = ?", serviceID, accountID, peerID, rpservice.SourceEphemeral). Update("meta_last_renewed_at", time.Now()) if result.Error != nil { log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error) return status.Errorf(status.Internal, "renew ephemeral service") } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "no active expose session for domain %s", domain) + return status.Errorf(status.NotFound, "no active expose session for service %s", serviceID) } return nil } @@ -5133,6 +5135,37 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock return id != "", nil } +// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port. +func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var services []*rpservice.Service + result := tx.Where("proxy_cluster = ? AND mode = ? AND listen_port = ?", proxyCluster, mode, listenPort).Find(&services) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "query services by cluster and port") + } + + return services, nil +} + +// GetServicesByCluster returns all services for the given proxy cluster. +func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var services []*rpservice.Service + result := tx.Where("proxy_cluster = ?", proxyCluster).Find(&services) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "query services by cluster") + } + return services, nil +} + func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) { tx := s.db diff --git a/management/server/store/store.go b/management/server/store/store.go index 1fa99fd05..8bb52f38a 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -261,10 +261,12 @@ type Store interface { GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) - RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error + RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) + GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) + GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) ListFreeDomains(ctx context.Context, accountID string) ([]string, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 130df4485..e75e35b94 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1991,6 +1991,36 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength) } +// GetServicesByCluster mocks base method. +func (m *MockStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServicesByCluster", ctx, lockStrength, proxyCluster) + ret0, _ := ret[0].([]*service.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServicesByCluster indicates an expected call of GetServicesByCluster. +func (mr *MockStoreMockRecorder) GetServicesByCluster(ctx, lockStrength, proxyCluster interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByCluster", reflect.TypeOf((*MockStore)(nil).GetServicesByCluster), ctx, lockStrength, proxyCluster) +} + +// GetServicesByClusterAndPort mocks base method. +func (m *MockStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster, mode string, listenPort uint16) ([]*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServicesByClusterAndPort", ctx, lockStrength, proxyCluster, mode, listenPort) + ret0, _ := ret[0].([]*service.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServicesByClusterAndPort indicates an expected call of GetServicesByClusterAndPort. +func (mr *MockStoreMockRecorder) GetServicesByClusterAndPort(ctx, lockStrength, proxyCluster, mode, listenPort interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByClusterAndPort", reflect.TypeOf((*MockStore)(nil).GetServicesByClusterAndPort), ctx, lockStrength, proxyCluster, mode, listenPort) +} + // GetSetupKeyByID mocks base method. func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) { m.ctrl.T.Helper() @@ -2447,17 +2477,17 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID } // RenewEphemeralService mocks base method. -func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error { +func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain) + ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, serviceID) ret0, _ := ret[0].(error) return ret0 } // RenewEphemeralService indicates an expected call of RenewEphemeralService. -func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, serviceID) } // RevokeProxyAccessToken mocks base method. diff --git a/management/server/types/account.go b/management/server/types/account.go index 6145ceeb2..269fc7a88 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -907,8 +907,8 @@ func (a *Account) Copy() *Account { } services := []*service.Service{} - for _, service := range a.Services { - services = append(services, service.Copy()) + for _, svc := range a.Services { + services = append(services, svc.Copy()) } return &Account{ @@ -1605,12 +1605,12 @@ func (a *Account) GetPoliciesForNetworkResource(resourceId string) []*Policy { networkResourceGroups := a.getNetworkResourceGroups(resourceId) for _, policy := range a.Policies { - if !policy.Enabled { + if policy == nil || !policy.Enabled { continue } for _, rule := range policy.Rules { - if !rule.Enabled { + if rule == nil || !rule.Enabled { continue } @@ -1812,15 +1812,18 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) { } a.injectServiceProxyPolicies(ctx, service, proxyPeersByCluster) } + } func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) { + proxyPeers := proxyPeersByCluster[service.ProxyCluster] for _, target := range service.Targets { if !target.Enabled { continue } - a.injectTargetProxyPolicies(ctx, service, target, proxyPeersByCluster[service.ProxyCluster]) + a.injectTargetProxyPolicies(ctx, service, target, proxyPeers) } + } func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) { @@ -1840,13 +1843,13 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *servic } } -func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) { +func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (uint16, bool) { if target.Port != 0 { return target.Port, true } switch target.Protocol { - case "https": + case "https", "tls": return 443, true case "http": return 80, true @@ -1856,17 +1859,23 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) } } -func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy { - policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path) +func (a *Account) createProxyPolicy(svc *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port uint16, path string) *Policy { + policyID := fmt.Sprintf("proxy-access-%s-%s-%s", svc.ID, proxyPeer.ID, path) + + protocol := PolicyRuleProtocolTCP + if svc.Mode == service.ModeUDP { + protocol = PolicyRuleProtocolUDP + } + return &Policy{ ID: policyID, - Name: fmt.Sprintf("Proxy Access to %s", service.Name), + Name: fmt.Sprintf("Proxy Access to %s", svc.Name), Enabled: true, Rules: []*PolicyRule{ { ID: policyID, PolicyID: policyID, - Name: fmt.Sprintf("Allow access to %s", service.Name), + Name: fmt.Sprintf("Allow access to %s", svc.Name), Enabled: true, SourceResource: Resource{ ID: proxyPeer.ID, @@ -1877,12 +1886,12 @@ func (a *Account) createProxyPolicy(service *service.Service, target *service.Ta Type: ResourceType(target.TargetType), }, Bidirectional: false, - Protocol: PolicyRuleProtocolTCP, + Protocol: protocol, Action: PolicyTrafficActionAccept, PortRanges: []RulePortRange{ { - Start: uint16(port), - End: uint16(port), + Start: port, + End: port, }, }, }, diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 60e81feb5..d82f5b7fc 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -7,6 +7,7 @@ import ( "os/signal" "strconv" "syscall" + "time" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -34,30 +35,32 @@ var ( ) var ( - logLevel string - debugLogs bool - mgmtAddr string - addr string - proxyDomain string - certDir string - acmeCerts bool - acmeAddr string - acmeDir string - acmeEABKID string - acmeEABHMACKey string - acmeChallengeType string - debugEndpoint bool - debugEndpointAddr string - healthAddr string - forwardedProto string - trustedProxies string - certFile string - certKeyFile string - certLockMethod string - wildcardCertDir string - wgPort int - proxyProtocol bool - preSharedKey string + logLevel string + debugLogs bool + mgmtAddr string + addr string + proxyDomain string + defaultDialTimeout time.Duration + certDir string + acmeCerts bool + acmeAddr string + acmeDir string + acmeEABKID string + acmeEABHMACKey string + acmeChallengeType string + debugEndpoint bool + debugEndpointAddr string + healthAddr string + forwardedProto string + trustedProxies string + certFile string + certKeyFile string + certLockMethod string + wildcardCertDir string + wgPort uint16 + proxyProtocol bool + preSharedKey string + supportsCustomPorts bool ) var rootCmd = &cobra.Command{ @@ -92,9 +95,11 @@ func init() { rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory") rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease") rootCmd.Flags().StringVar(&wildcardCertDir, "wildcard-cert-dir", envStringOrDefault("NB_PROXY_WILDCARD_CERT_DIR", ""), "Directory containing wildcard certificate pairs (.crt/.key). Wildcard patterns are extracted from SANs automatically") - rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") + rootCmd.Flags().Uint16Var(&wgPort, "wg-port", envUint16OrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers") + rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough") + rootCmd.Flags().DurationVar(&defaultDialTimeout, "default-dial-timeout", envDurationOrDefault("NB_PROXY_DEFAULT_DIAL_TIMEOUT", 0), "Default backend dial timeout when no per-service timeout is set (e.g. 30s)") } // Execute runs the root command. @@ -171,6 +176,8 @@ func runServer(cmd *cobra.Command, args []string) error { WireguardPort: wgPort, ProxyProtocol: proxyProtocol, PreSharedKey: preSharedKey, + SupportsCustomPorts: supportsCustomPorts, + DefaultDialTimeout: defaultDialTimeout, } ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) @@ -203,12 +210,24 @@ func envStringOrDefault(key string, def string) string { return v } -func envIntOrDefault(key string, def int) int { +func envUint16OrDefault(key string, def uint16) uint16 { v, exists := os.LookupEnv(key) if !exists { return def } - parsed, err := strconv.Atoi(v) + parsed, err := strconv.ParseUint(v, 10, 16) + if err != nil { + return def + } + return uint16(parsed) +} + +func envDurationOrDefault(key string, def time.Duration) time.Duration { + v, exists := os.LookupEnv(key) + if !exists { + return def + } + parsed, err := time.ParseDuration(v) if err != nil { return def } diff --git a/proxy/handle_mapping_stream_test.go b/proxy/handle_mapping_stream_test.go index d2ad3f67e..cb16c0814 100644 --- a/proxy/handle_mapping_stream_test.go +++ b/proxy/handle_mapping_stream_test.go @@ -38,11 +38,18 @@ func (m *mockMappingStream) Context() context.Context { return context.Backgroun func (m *mockMappingStream) SendMsg(any) error { return nil } func (m *mockMappingStream) RecvMsg(any) error { return nil } +func closedChan() chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +} + func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) { checker := health.NewChecker(nil, nil) s := &Server{ Logger: log.StandardLogger(), healthChecker: checker, + routerReady: closedChan(), } stream := &mockMappingStream{ @@ -62,6 +69,7 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) { s := &Server{ Logger: log.StandardLogger(), healthChecker: checker, + routerReady: closedChan(), } stream := &mockMappingStream{ @@ -78,7 +86,8 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) { func TestHandleMappingStream_NilHealthChecker(t *testing.T) { s := &Server{ - Logger: log.StandardLogger(), + Logger: log.StandardLogger(), + routerReady: closedChan(), } stream := &mockMappingStream{ diff --git a/proxy/internal/accesslog/logger.go b/proxy/internal/accesslog/logger.go index 4ba5a7755..5b05ab195 100644 --- a/proxy/internal/accesslog/logger.go +++ b/proxy/internal/accesslog/logger.go @@ -6,11 +6,13 @@ import ( "sync" "time" + "github.com/rs/xid" log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -19,6 +21,7 @@ const ( bytesThreshold = 1024 * 1024 * 1024 // Log every 1GB usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours + logSendTimeout = 10 * time.Second ) type domainUsage struct { @@ -79,22 +82,63 @@ func (l *Logger) Close() { type logEntry struct { ID string - AccountID string - ServiceId string + AccountID types.AccountID + ServiceId types.ServiceID Host string Path string DurationMs int64 Method string ResponseCode int32 - SourceIp string + SourceIP netip.Addr AuthMechanism string UserId string AuthSuccess bool BytesUpload int64 BytesDownload int64 + Protocol Protocol } -func (l *Logger) log(ctx context.Context, entry logEntry) { +// Protocol identifies the transport protocol of an access log entry. +type Protocol string + +const ( + ProtocolHTTP Protocol = "http" + ProtocolTCP Protocol = "tcp" + ProtocolUDP Protocol = "udp" + ProtocolTLS Protocol = "tls" +) + +// L4Entry holds the data for a layer-4 (TCP/UDP) access log entry. +type L4Entry struct { + AccountID types.AccountID + ServiceID types.ServiceID + Protocol Protocol + Host string // SNI hostname or listen address + SourceIP netip.Addr + DurationMs int64 + BytesUpload int64 + BytesDownload int64 +} + +// LogL4 sends an access log entry for a layer-4 connection (TCP or UDP). +// The call is non-blocking: the gRPC send happens in a background goroutine. +func (l *Logger) LogL4(entry L4Entry) { + le := logEntry{ + ID: xid.New().String(), + AccountID: entry.AccountID, + ServiceId: entry.ServiceID, + Protocol: entry.Protocol, + Host: entry.Host, + SourceIP: entry.SourceIP, + DurationMs: entry.DurationMs, + BytesUpload: entry.BytesUpload, + BytesDownload: entry.BytesDownload, + } + l.log(le) + l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload) +} + +func (l *Logger) log(entry logEntry) { // Fire off the log request in a separate routine. // This increases the possibility of losing a log message // (although it should still get logged in the event of an error), @@ -105,31 +149,37 @@ func (l *Logger) log(ctx context.Context, entry logEntry) { // allow for resolving that on the server. now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary. go func() { - logCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout) defer cancel() if entry.AuthMechanism != auth.MethodOIDC.String() { entry.UserId = "" } + + var sourceIP string + if entry.SourceIP.IsValid() { + sourceIP = entry.SourceIP.String() + } + if _, err := l.client.SendAccessLog(logCtx, &proto.SendAccessLogRequest{ Log: &proto.AccessLog{ LogId: entry.ID, - AccountId: entry.AccountID, + AccountId: string(entry.AccountID), Timestamp: now, - ServiceId: entry.ServiceId, + ServiceId: string(entry.ServiceId), Host: entry.Host, Path: entry.Path, DurationMs: entry.DurationMs, Method: entry.Method, ResponseCode: entry.ResponseCode, - SourceIp: entry.SourceIp, + SourceIp: sourceIP, AuthMechanism: entry.AuthMechanism, UserId: entry.UserId, AuthSuccess: entry.AuthSuccess, BytesUpload: entry.BytesUpload, BytesDownload: entry.BytesDownload, + Protocol: string(entry.Protocol), }, }); err != nil { - // If it fails to send on the gRPC connection, then at least log it to the error log. l.logger.WithFields(log.Fields{ "service_id": entry.ServiceId, "host": entry.Host, @@ -137,7 +187,7 @@ func (l *Logger) log(ctx context.Context, entry logEntry) { "duration": entry.DurationMs, "method": entry.Method, "response_code": entry.ResponseCode, - "source_ip": entry.SourceIp, + "source_ip": sourceIP, "auth_mechanism": entry.AuthMechanism, "user_id": entry.UserId, "auth_success": entry.AuthSuccess, diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index 7368185c0..593a77ef2 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -67,23 +67,24 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { entry := logEntry{ ID: requestID, ServiceId: capturedData.GetServiceId(), - AccountID: string(capturedData.GetAccountId()), + AccountID: capturedData.GetAccountId(), Host: host, Path: r.URL.Path, DurationMs: duration.Milliseconds(), Method: r.Method, ResponseCode: int32(sw.status), - SourceIp: sourceIp, + SourceIP: sourceIp, AuthMechanism: capturedData.GetAuthMethod(), UserId: capturedData.GetUserID(), AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden, BytesUpload: bytesUpload, BytesDownload: bytesDownload, + Protocol: ProtocolHTTP, } l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId()) - l.log(r.Context(), entry) + l.log(entry) // Track usage for cost monitoring (upload + download) by domain l.trackUsage(host, bytesUpload+bytesDownload) diff --git a/proxy/internal/accesslog/requestip.go b/proxy/internal/accesslog/requestip.go index f111c1322..30c483fd9 100644 --- a/proxy/internal/accesslog/requestip.go +++ b/proxy/internal/accesslog/requestip.go @@ -11,6 +11,6 @@ import ( // proxy configuration. When trustedProxies is non-empty and the direct // connection is from a trusted source, it walks X-Forwarded-For right-to-left // skipping trusted IPs. Otherwise it returns RemoteAddr directly. -func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) string { +func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) netip.Addr { return proxy.ResolveClientIP(r.RemoteAddr, r.Header.Get("X-Forwarded-For"), trustedProxies) } diff --git a/proxy/internal/acme/manager.go b/proxy/internal/acme/manager.go index 395da7d88..a4a220ed7 100644 --- a/proxy/internal/acme/manager.go +++ b/proxy/internal/acme/manager.go @@ -23,6 +23,7 @@ import ( "golang.org/x/crypto/acme/autocert" "github.com/netbirdio/netbird/proxy/internal/certwatch" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -30,7 +31,7 @@ import ( var oidSCTList = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 2} type certificateNotifier interface { - NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error + NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error } type domainState int @@ -42,8 +43,8 @@ const ( ) type domainInfo struct { - accountID string - serviceID string + accountID types.AccountID + serviceID types.ServiceID state domainState err string } @@ -301,7 +302,7 @@ func (mgr *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate // When AddDomain returns true the caller is responsible for sending any // certificate-ready notifications after the surrounding operation (e.g. // mapping update) has committed successfully. -func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) (wildcardHit bool) { +func (mgr *Manager) AddDomain(d domain.Domain, accountID types.AccountID, serviceID types.ServiceID) (wildcardHit bool) { name := d.PunycodeString() if e := mgr.findWildcardEntry(name); e != nil { mgr.mu.Lock() diff --git a/proxy/internal/acme/manager_test.go b/proxy/internal/acme/manager_test.go index 9a3ed9efd..ceb9ca13a 100644 --- a/proxy/internal/acme/manager_test.go +++ b/proxy/internal/acme/manager_test.go @@ -17,12 +17,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" ) func TestHostPolicy(t *testing.T) { mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil) require.NoError(t, err) - mgr.AddDomain("example.com", "acc1", "rp1") + mgr.AddDomain("example.com", types.AccountID("acc1"), types.ServiceID("rp1")) // Wait for the background prefetch goroutine to finish so the temp dir // can be cleaned up without a race. @@ -92,8 +94,8 @@ func TestDomainStates(t *testing.T) { // AddDomain starts as pending, then the prefetch goroutine will fail // (no real ACME server) and transition to failed. - mgr.AddDomain("a.example.com", "acc1", "rp1") - mgr.AddDomain("b.example.com", "acc1", "rp1") + mgr.AddDomain("a.example.com", types.AccountID("acc1"), types.ServiceID("rp1")) + mgr.AddDomain("b.example.com", types.AccountID("acc1"), types.ServiceID("rp1")) assert.Equal(t, 2, mgr.TotalDomains(), "two domains registered") @@ -209,12 +211,12 @@ func TestWildcardAddDomainSkipsACME(t *testing.T) { require.NoError(t, err) // Add a wildcard-matching domain — should be immediately ready. - mgr.AddDomain("foo.example.com", "acc1", "svc1") + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) assert.Equal(t, 0, mgr.PendingCerts(), "wildcard domain should not be pending") assert.Equal(t, []string{"foo.example.com"}, mgr.ReadyDomains()) // Add a non-wildcard domain — should go through ACME (pending then failed). - mgr.AddDomain("other.net", "acc2", "svc2") + mgr.AddDomain("other.net", types.AccountID("acc2"), types.ServiceID("svc2")) assert.Equal(t, 2, mgr.TotalDomains()) // Wait for the ACME prefetch to fail. @@ -234,7 +236,7 @@ func TestWildcardGetCertificate(t *testing.T) { mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) require.NoError(t, err) - mgr.AddDomain("foo.example.com", "acc1", "svc1") + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) // GetCertificate for a wildcard-matching domain should return the static cert. cert, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) @@ -255,8 +257,8 @@ func TestMultipleWildcards(t *testing.T) { assert.ElementsMatch(t, []string{"*.example.com", "*.other.org"}, mgr.WildcardPatterns()) // Both wildcards should resolve. - mgr.AddDomain("foo.example.com", "acc1", "svc1") - mgr.AddDomain("bar.other.org", "acc2", "svc2") + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) + mgr.AddDomain("bar.other.org", types.AccountID("acc2"), types.ServiceID("svc2")) assert.Equal(t, 0, mgr.PendingCerts()) assert.ElementsMatch(t, []string{"foo.example.com", "bar.other.org"}, mgr.ReadyDomains()) @@ -271,7 +273,7 @@ func TestMultipleWildcards(t *testing.T) { assert.Contains(t, cert2.Leaf.DNSNames, "*.other.org") // Non-matching domain falls through to ACME. - mgr.AddDomain("custom.net", "acc3", "svc3") + mgr.AddDomain("custom.net", types.AccountID("acc3"), types.ServiceID("svc3")) assert.Eventually(t, func() bool { return mgr.PendingCerts() == 0 }, 30*time.Second, 100*time.Millisecond) diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 8a966faa3..3cf86e4b3 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -44,8 +44,8 @@ type DomainConfig struct { Schemes []Scheme SessionPublicKey ed25519.PublicKey SessionExpiration time.Duration - AccountID string - ServiceID string + AccountID types.AccountID + ServiceID types.ServiceID } type validationResult struct { @@ -124,7 +124,7 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) { func setCapturedIDs(r *http.Request, config DomainConfig) { if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetAccountId(types.AccountID(config.AccountID)) + cd.SetAccountId(config.AccountID) cd.SetServiceId(config.ServiceID) } } @@ -275,7 +275,7 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool { // session JWTs. Returns an error if the key is missing or invalid. // Callers must not serve the domain if this returns an error, to avoid // exposing an unauthenticated service. -func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error { +func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error { if len(schemes) == 0 { mw.domainsMux.Lock() defer mw.domainsMux.Unlock() diff --git a/proxy/internal/auth/oidc.go b/proxy/internal/auth/oidc.go index bf178d432..a60e6437a 100644 --- a/proxy/internal/auth/oidc.go +++ b/proxy/internal/auth/oidc.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -17,14 +18,14 @@ type urlGenerator interface { } type OIDC struct { - id string - accountId string + id types.ServiceID + accountId types.AccountID forwardedProto string client urlGenerator } // NewOIDC creates a new OIDC authentication scheme -func NewOIDC(client urlGenerator, id, accountId, forwardedProto string) OIDC { +func NewOIDC(client urlGenerator, id types.ServiceID, accountId types.AccountID, forwardedProto string) OIDC { return OIDC{ id: id, accountId: accountId, @@ -53,8 +54,8 @@ func (o OIDC) Authenticate(r *http.Request) (string, string, error) { } res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{ - Id: o.id, - AccountId: o.accountId, + Id: string(o.id), + AccountId: string(o.accountId), RedirectUrl: redirectURL.String(), }) if err != nil { diff --git a/proxy/internal/auth/password.go b/proxy/internal/auth/password.go index 208423465..6a7eda3e1 100644 --- a/proxy/internal/auth/password.go +++ b/proxy/internal/auth/password.go @@ -5,17 +5,19 @@ import ( "net/http" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) const passwordFormId = "password" type Password struct { - id, accountId string - client authenticator + id types.ServiceID + accountId types.AccountID + client authenticator } -func NewPassword(client authenticator, id, accountId string) Password { +func NewPassword(client authenticator, id types.ServiceID, accountId types.AccountID) Password { return Password{ id: id, accountId: accountId, @@ -41,8 +43,8 @@ func (p Password) Authenticate(r *http.Request) (string, string, error) { } res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{ - Id: p.id, - AccountId: p.accountId, + Id: string(p.id), + AccountId: string(p.accountId), Request: &proto.AuthenticateRequest_Password{ Password: &proto.PasswordRequest{ Password: password, diff --git a/proxy/internal/auth/pin.go b/proxy/internal/auth/pin.go index c1eb56071..4d08f3dc6 100644 --- a/proxy/internal/auth/pin.go +++ b/proxy/internal/auth/pin.go @@ -5,17 +5,19 @@ import ( "net/http" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) const pinFormId = "pin" type Pin struct { - id, accountId string - client authenticator + id types.ServiceID + accountId types.AccountID + client authenticator } -func NewPin(client authenticator, id, accountId string) Pin { +func NewPin(client authenticator, id types.ServiceID, accountId types.AccountID) Pin { return Pin{ id: id, accountId: accountId, @@ -41,8 +43,8 @@ func (p Pin) Authenticate(r *http.Request) (string, string, error) { } res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{ - Id: p.id, - AccountId: p.accountId, + Id: string(p.id), + AccountId: string(p.accountId), Request: &proto.AuthenticateRequest_Pin{ Pin: &proto.PinRequest{ Pin: pin, diff --git a/proxy/internal/conntrack/conn.go b/proxy/internal/conntrack/conn.go index 97055d992..8446d638f 100644 --- a/proxy/internal/conntrack/conn.go +++ b/proxy/internal/conntrack/conn.go @@ -10,10 +10,11 @@ import ( type trackedConn struct { net.Conn tracker *HijackTracker + host string } func (c *trackedConn) Close() error { - c.tracker.conns.Delete(c) + c.tracker.remove(c) return c.Conn.Close() } @@ -22,6 +23,7 @@ func (c *trackedConn) Close() error { type trackingWriter struct { http.ResponseWriter tracker *HijackTracker + host string } func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { @@ -33,8 +35,8 @@ func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if err != nil { return nil, nil, err } - tc := &trackedConn{Conn: conn, tracker: w.tracker} - w.tracker.conns.Store(tc, struct{}{}) + tc := &trackedConn{Conn: conn, tracker: w.tracker, host: w.host} + w.tracker.add(tc) return tc, buf, nil } diff --git a/proxy/internal/conntrack/hijacked.go b/proxy/internal/conntrack/hijacked.go index d76cebc08..911f93f3d 100644 --- a/proxy/internal/conntrack/hijacked.go +++ b/proxy/internal/conntrack/hijacked.go @@ -1,7 +1,6 @@ package conntrack import ( - "net" "net/http" "sync" ) @@ -10,10 +9,14 @@ import ( // upgrades). http.Server.Shutdown does not close hijacked connections, so // they must be tracked and closed explicitly during graceful shutdown. // +// Connections are indexed by the request Host so they can be closed +// per-domain when a service mapping is removed. +// // Use Middleware as the outermost HTTP middleware to ensure hijacked // connections are tracked and automatically deregistered when closed. type HijackTracker struct { - conns sync.Map // net.Conn → struct{} + mu sync.Mutex + conns map[*trackedConn]struct{} } // Middleware returns an HTTP middleware that wraps the ResponseWriter so that @@ -21,21 +24,73 @@ type HijackTracker struct { // tracker when closed. This should be the outermost middleware in the chain. func (t *HijackTracker) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r) + next.ServeHTTP(&trackingWriter{ + ResponseWriter: w, + tracker: t, + host: hostOnly(r.Host), + }, r) }) } -// CloseAll closes all tracked hijacked connections and returns the number -// of connections that were closed. +// CloseAll closes all tracked hijacked connections and returns the count. func (t *HijackTracker) CloseAll() int { - var count int - t.conns.Range(func(key, _ any) bool { - if conn, ok := key.(net.Conn); ok { - _ = conn.Close() - count++ - } - t.conns.Delete(key) - return true - }) - return count + t.mu.Lock() + conns := t.conns + t.conns = nil + t.mu.Unlock() + + for tc := range conns { + _ = tc.Conn.Close() + } + return len(conns) +} + +// CloseByHost closes all tracked hijacked connections for the given host +// and returns the number of connections closed. +func (t *HijackTracker) CloseByHost(host string) int { + host = hostOnly(host) + t.mu.Lock() + var toClose []*trackedConn + for tc := range t.conns { + if tc.host == host { + toClose = append(toClose, tc) + } + } + for _, tc := range toClose { + delete(t.conns, tc) + } + t.mu.Unlock() + + for _, tc := range toClose { + _ = tc.Conn.Close() + } + return len(toClose) +} + +func (t *HijackTracker) add(tc *trackedConn) { + t.mu.Lock() + if t.conns == nil { + t.conns = make(map[*trackedConn]struct{}) + } + t.conns[tc] = struct{}{} + t.mu.Unlock() +} + +func (t *HijackTracker) remove(tc *trackedConn) { + t.mu.Lock() + delete(t.conns, tc) + t.mu.Unlock() +} + +// hostOnly strips the port from a host:port string. +func hostOnly(hostport string) string { + for i := len(hostport) - 1; i >= 0; i-- { + if hostport[i] == ':' { + return hostport[:i] + } + if hostport[i] < '0' || hostport[i] > '9' { + return hostport + } + } + return hostport } diff --git a/proxy/internal/conntrack/hijacked_test.go b/proxy/internal/conntrack/hijacked_test.go new file mode 100644 index 000000000..9ceefff78 --- /dev/null +++ b/proxy/internal/conntrack/hijacked_test.go @@ -0,0 +1,142 @@ +package conntrack + +import ( + "bufio" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeHijackWriter implements http.ResponseWriter and http.Hijacker for testing. +type fakeHijackWriter struct { + http.ResponseWriter + conn net.Conn +} + +func (f *fakeHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) + return f.conn, rw, nil +} + +func TestCloseByHost(t *testing.T) { + var tracker HijackTracker + + // Simulate hijacking two connections for different hosts. + connA1, connA2 := net.Pipe() + defer connA2.Close() + connB1, connB2 := net.Pipe() + defer connB2.Close() + + twA := &trackingWriter{ + ResponseWriter: httptest.NewRecorder(), + tracker: &tracker, + host: "a.example.com", + } + twB := &trackingWriter{ + ResponseWriter: httptest.NewRecorder(), + tracker: &tracker, + host: "b.example.com", + } + + // Use fakeHijackWriter to provide the Hijack method. + twA.ResponseWriter = &fakeHijackWriter{ResponseWriter: twA.ResponseWriter, conn: connA1} + twB.ResponseWriter = &fakeHijackWriter{ResponseWriter: twB.ResponseWriter, conn: connB1} + + _, _, err := twA.Hijack() + require.NoError(t, err) + _, _, err = twB.Hijack() + require.NoError(t, err) + + tracker.mu.Lock() + assert.Equal(t, 2, len(tracker.conns), "should track 2 connections") + tracker.mu.Unlock() + + // Close only host A. + n := tracker.CloseByHost("a.example.com") + assert.Equal(t, 1, n, "should close 1 connection for host A") + + tracker.mu.Lock() + assert.Equal(t, 1, len(tracker.conns), "should have 1 remaining connection") + tracker.mu.Unlock() + + // Verify host A's conn is actually closed. + buf := make([]byte, 1) + _, err = connA2.Read(buf) + assert.Error(t, err, "host A pipe should be closed") + + // Host B should still be alive. + go func() { _, _ = connB1.Write([]byte("x")) }() + + // Close all remaining. + n = tracker.CloseAll() + assert.Equal(t, 1, n, "should close remaining 1 connection") + + tracker.mu.Lock() + assert.Equal(t, 0, len(tracker.conns), "should have 0 connections after CloseAll") + tracker.mu.Unlock() +} + +func TestCloseAll(t *testing.T) { + var tracker HijackTracker + + for range 5 { + c1, c2 := net.Pipe() + defer c2.Close() + tc := &trackedConn{Conn: c1, tracker: &tracker, host: "test.com"} + tracker.add(tc) + } + + tracker.mu.Lock() + assert.Equal(t, 5, len(tracker.conns)) + tracker.mu.Unlock() + + n := tracker.CloseAll() + assert.Equal(t, 5, n) + + // Double CloseAll is safe. + n = tracker.CloseAll() + assert.Equal(t, 0, n) +} + +func TestTrackedConn_AutoDeregister(t *testing.T) { + var tracker HijackTracker + + c1, c2 := net.Pipe() + defer c2.Close() + + tc := &trackedConn{Conn: c1, tracker: &tracker, host: "auto.com"} + tracker.add(tc) + + tracker.mu.Lock() + assert.Equal(t, 1, len(tracker.conns)) + tracker.mu.Unlock() + + // Close the tracked conn: should auto-deregister. + require.NoError(t, tc.Close()) + + tracker.mu.Lock() + assert.Equal(t, 0, len(tracker.conns), "should auto-deregister on close") + tracker.mu.Unlock() +} + +func TestHostOnly(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"example.com:443", "example.com"}, + {"example.com", "example.com"}, + {"127.0.0.1:8080", "127.0.0.1"}, + {"[::1]:443", "[::1]"}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + assert.Equal(t, tt.want, hostOnly(tt.input)) + }) + } +} diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 885c574bc..01b0bc8e6 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -152,7 +152,7 @@ func (c *Client) printClients(data map[string]any) { return } - _, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT") + _, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "SERVICES", "HAS CLIENT") _, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110)) for _, item := range clients { @@ -166,7 +166,7 @@ func (c *Client) printClientRow(item any) { return } - domains := c.extractDomains(client) + services := c.extractServiceKeys(client) hasClient := "no" if hc, ok := client["has_client"].(bool); ok && hc { hasClient = "yes" @@ -175,20 +175,20 @@ func (c *Client) printClientRow(item any) { _, _ = fmt.Fprintf(c.out, "%-38s %-12v %s %s\n", client["account_id"], client["age"], - domains, + services, hasClient, ) } -func (c *Client) extractDomains(client map[string]any) string { - d, ok := client["domains"].([]any) +func (c *Client) extractServiceKeys(client map[string]any) string { + d, ok := client["service_keys"].([]any) if !ok || len(d) == 0 { return "-" } parts := make([]string, len(d)) - for i, domain := range d { - parts[i] = fmt.Sprint(domain) + for i, key := range d { + parts[i] = fmt.Sprint(key) } return strings.Join(parts, ", ") } diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index ab75c8b72..237010922 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -189,7 +189,7 @@ type indexData struct { Version string Uptime string ClientCount int - TotalDomains int + TotalServices int CertsTotal int CertsReady int CertsPending int @@ -202,7 +202,7 @@ type indexData struct { type clientData struct { AccountID string - Domains string + Services string Age string Status string } @@ -211,9 +211,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b clients := h.provider.ListClientsForDebug() sortedIDs := sortedAccountIDs(clients) - totalDomains := 0 + totalServices := 0 for _, info := range clients { - totalDomains += info.DomainCount + totalServices += info.ServiceCount } var certsTotal, certsReady, certsPending, certsFailed int @@ -234,24 +234,24 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b for _, id := range sortedIDs { info := clients[id] clientsJSON = append(clientsJSON, map[string]interface{}{ - "account_id": info.AccountID, - "domain_count": info.DomainCount, - "domains": info.Domains, - "has_client": info.HasClient, - "created_at": info.CreatedAt, - "age": time.Since(info.CreatedAt).Round(time.Second).String(), + "account_id": info.AccountID, + "service_count": info.ServiceCount, + "service_keys": info.ServiceKeys, + "has_client": info.HasClient, + "created_at": info.CreatedAt, + "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } resp := map[string]interface{}{ - "version": version.NetbirdVersion(), - "uptime": time.Since(h.startTime).Round(time.Second).String(), - "client_count": len(clients), - "total_domains": totalDomains, - "certs_total": certsTotal, - "certs_ready": certsReady, - "certs_pending": certsPending, - "certs_failed": certsFailed, - "clients": clientsJSON, + "version": version.NetbirdVersion(), + "uptime": time.Since(h.startTime).Round(time.Second).String(), + "client_count": len(clients), + "total_services": totalServices, + "certs_total": certsTotal, + "certs_ready": certsReady, + "certs_pending": certsPending, + "certs_failed": certsFailed, + "clients": clientsJSON, } if len(certsPendingDomains) > 0 { resp["certs_pending_domains"] = certsPendingDomains @@ -278,7 +278,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b Version: version.NetbirdVersion(), Uptime: time.Since(h.startTime).Round(time.Second).String(), ClientCount: len(clients), - TotalDomains: totalDomains, + TotalServices: totalServices, CertsTotal: certsTotal, CertsReady: certsReady, CertsPending: certsPending, @@ -291,9 +291,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b for _, id := range sortedIDs { info := clients[id] - domains := info.Domains.SafeString() - if domains == "" { - domains = "-" + services := strings.Join(info.ServiceKeys, ", ") + if services == "" { + services = "-" } status := "No client" if info.HasClient { @@ -301,7 +301,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b } data.Clients = append(data.Clients, clientData{ AccountID: string(info.AccountID), - Domains: domains, + Services: services, Age: time.Since(info.CreatedAt).Round(time.Second).String(), Status: status, }) @@ -324,12 +324,12 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want for _, id := range sortedIDs { info := clients[id] clientsJSON = append(clientsJSON, map[string]interface{}{ - "account_id": info.AccountID, - "domain_count": info.DomainCount, - "domains": info.Domains, - "has_client": info.HasClient, - "created_at": info.CreatedAt, - "age": time.Since(info.CreatedAt).Round(time.Second).String(), + "account_id": info.AccountID, + "service_count": info.ServiceCount, + "service_keys": info.ServiceKeys, + "has_client": info.HasClient, + "created_at": info.CreatedAt, + "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } h.writeJSON(w, map[string]interface{}{ @@ -347,9 +347,9 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want for _, id := range sortedIDs { info := clients[id] - domains := info.Domains.SafeString() - if domains == "" { - domains = "-" + services := strings.Join(info.ServiceKeys, ", ") + if services == "" { + services = "-" } status := "No client" if info.HasClient { @@ -357,7 +357,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want } data.Clients = append(data.Clients, clientData{ AccountID: string(info.AccountID), - Domains: domains, + Services: services, Age: time.Since(info.CreatedAt).Round(time.Second).String(), Status: status, }) diff --git a/proxy/internal/debug/templates/clients.html b/proxy/internal/debug/templates/clients.html index 4d455b2bb..bfc25f95a 100644 --- a/proxy/internal/debug/templates/clients.html +++ b/proxy/internal/debug/templates/clients.html @@ -12,14 +12,14 @@ - + {{range .Clients}} - + diff --git a/proxy/internal/debug/templates/index.html b/proxy/internal/debug/templates/index.html index 16ab3d979..5bd25adfc 100644 --- a/proxy/internal/debug/templates/index.html +++ b/proxy/internal/debug/templates/index.html @@ -27,19 +27,19 @@
    {{range .CertsFailedDomains}}
  • {{.Domain}}: {{.Error}}
  • {{end}}
{{end}} -

Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})

+

Clients ({{.ClientCount}}) | Services ({{.TotalServices}})

{{if .Clients}}
Account IDDomainsServices Age Status
{{.AccountID}}{{.Domains}}{{.Services}} {{.Age}} {{.Status}}
- + {{range .Clients}} - + diff --git a/proxy/internal/metrics/l4_metrics_test.go b/proxy/internal/metrics/l4_metrics_test.go new file mode 100644 index 000000000..055158828 --- /dev/null +++ b/proxy/internal/metrics/l4_metrics_test.go @@ -0,0 +1,69 @@ +package metrics_test + +import ( + "context" + "reflect" + "testing" + "time" + + promexporter "go.opentelemetry.io/otel/exporters/prometheus" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + + "github.com/netbirdio/netbird/proxy/internal/metrics" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func newTestMetrics(t *testing.T) *metrics.Metrics { + t.Helper() + + exporter, err := promexporter.New() + if err != nil { + t.Fatalf("create prometheus exporter: %v", err) + } + + provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(exporter)) + pkg := reflect.TypeOf(metrics.Metrics{}).PkgPath() + meter := provider.Meter(pkg) + + m, err := metrics.New(context.Background(), meter) + if err != nil { + t.Fatalf("create metrics: %v", err) + } + return m +} + +func TestL4ServiceGauge(t *testing.T) { + m := newTestMetrics(t) + + m.L4ServiceAdded(types.ServiceModeTCP) + m.L4ServiceAdded(types.ServiceModeTCP) + m.L4ServiceAdded(types.ServiceModeUDP) + m.L4ServiceRemoved(types.ServiceModeTCP) +} + +func TestTCPRelayMetrics(t *testing.T) { + m := newTestMetrics(t) + + acct := types.AccountID("acct-1") + + m.TCPRelayStarted(acct) + m.TCPRelayStarted(acct) + m.TCPRelayEnded(acct, 10*time.Second, 1000, 500) + m.TCPRelayDialError(acct) + m.TCPRelayRejected(acct) +} + +func TestUDPSessionMetrics(t *testing.T) { + m := newTestMetrics(t) + + acct := types.AccountID("acct-2") + + m.UDPSessionStarted(acct) + m.UDPSessionStarted(acct) + m.UDPSessionEnded(acct) + m.UDPSessionDialError(acct) + m.UDPSessionRejected(acct) + m.UDPPacketRelayed(types.RelayDirectionClientToBackend, 100) + m.UDPPacketRelayed(types.RelayDirectionClientToBackend, 200) + m.UDPPacketRelayed(types.RelayDirectionBackendToClient, 150) +} diff --git a/proxy/internal/metrics/metrics.go b/proxy/internal/metrics/metrics.go index 68ff55fe5..573485625 100644 --- a/proxy/internal/metrics/metrics.go +++ b/proxy/internal/metrics/metrics.go @@ -6,12 +6,15 @@ import ( "sync" "time" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/responsewriter" + "github.com/netbirdio/netbird/proxy/internal/types" ) +// Metrics collects OpenTelemetry metrics for the proxy. type Metrics struct { ctx context.Context requestsTotal metric.Int64Counter @@ -22,85 +25,188 @@ type Metrics struct { backendDuration metric.Int64Histogram certificateIssueDuration metric.Int64Histogram + // L4 service-level metrics. + l4Services metric.Int64UpDownCounter + + // L4 TCP connection-level metrics. + tcpActiveConns metric.Int64UpDownCounter + tcpConnsTotal metric.Int64Counter + tcpConnDuration metric.Int64Histogram + tcpBytesTotal metric.Int64Counter + + // L4 UDP session-level metrics. + udpActiveSess metric.Int64UpDownCounter + udpSessionsTotal metric.Int64Counter + udpPacketsTotal metric.Int64Counter + udpBytesTotal metric.Int64Counter + mappingsMux sync.Mutex mappingPaths map[string]int } +// New creates a Metrics instance using the given OpenTelemetry meter. func New(ctx context.Context, meter metric.Meter) (*Metrics, error) { - requestsTotal, err := meter.Int64Counter( + m := &Metrics{ + ctx: ctx, + mappingPaths: make(map[string]int), + } + + if err := m.initHTTPMetrics(meter); err != nil { + return nil, err + } + if err := m.initL4Metrics(meter); err != nil { + return nil, err + } + + return m, nil +} + +func (m *Metrics) initHTTPMetrics(meter metric.Meter) error { + var err error + + m.requestsTotal, err = meter.Int64Counter( "proxy.http.request.counter", metric.WithUnit("1"), metric.WithDescription("Total number of requests made to the netbird proxy"), ) if err != nil { - return nil, err + return err } - activeRequests, err := meter.Int64UpDownCounter( + m.activeRequests, err = meter.Int64UpDownCounter( "proxy.http.active_requests", metric.WithUnit("1"), metric.WithDescription("Current in-flight requests handled by the netbird proxy"), ) if err != nil { - return nil, err + return err } - configuredDomains, err := meter.Int64UpDownCounter( + m.configuredDomains, err = meter.Int64UpDownCounter( "proxy.domains.count", metric.WithUnit("1"), metric.WithDescription("Current number of domains configured on the netbird proxy"), ) if err != nil { - return nil, err + return err } - totalPaths, err := meter.Int64UpDownCounter( + m.totalPaths, err = meter.Int64UpDownCounter( "proxy.paths.count", metric.WithUnit("1"), metric.WithDescription("Total number of paths configured on the netbird proxy"), ) if err != nil { - return nil, err + return err } - requestDuration, err := meter.Int64Histogram( + m.requestDuration, err = meter.Int64Histogram( "proxy.http.request.duration.ms", metric.WithUnit("milliseconds"), metric.WithDescription("Duration of requests made to the netbird proxy"), ) if err != nil { - return nil, err + return err } - backendDuration, err := meter.Int64Histogram( + m.backendDuration, err = meter.Int64Histogram( "proxy.backend.duration.ms", metric.WithUnit("milliseconds"), metric.WithDescription("Duration of peer round trip time from the netbird proxy"), ) if err != nil { - return nil, err + return err } - certificateIssueDuration, err := meter.Int64Histogram( + m.certificateIssueDuration, err = meter.Int64Histogram( "proxy.certificate.issue.duration.ms", metric.WithUnit("milliseconds"), metric.WithDescription("Duration of ACME certificate issuance"), ) + return err +} + +func (m *Metrics) initL4Metrics(meter metric.Meter) error { + var err error + + m.l4Services, err = meter.Int64UpDownCounter( + "proxy.l4.services.count", + metric.WithUnit("1"), + metric.WithDescription("Current number of configured L4 services (TCP/TLS/UDP) by mode"), + ) if err != nil { - return nil, err + return err } - return &Metrics{ - ctx: ctx, - requestsTotal: requestsTotal, - activeRequests: activeRequests, - configuredDomains: configuredDomains, - totalPaths: totalPaths, - requestDuration: requestDuration, - backendDuration: backendDuration, - certificateIssueDuration: certificateIssueDuration, - mappingPaths: make(map[string]int), - }, nil + m.tcpActiveConns, err = meter.Int64UpDownCounter( + "proxy.tcp.active_connections", + metric.WithUnit("1"), + metric.WithDescription("Current number of active TCP/TLS relay connections"), + ) + if err != nil { + return err + } + + m.tcpConnsTotal, err = meter.Int64Counter( + "proxy.tcp.connections.total", + metric.WithUnit("1"), + metric.WithDescription("Total TCP/TLS relay connections by result and account"), + ) + if err != nil { + return err + } + + m.tcpConnDuration, err = meter.Int64Histogram( + "proxy.tcp.connection.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration of TCP/TLS relay connections"), + ) + if err != nil { + return err + } + + m.tcpBytesTotal, err = meter.Int64Counter( + "proxy.tcp.bytes.total", + metric.WithUnit("bytes"), + metric.WithDescription("Total bytes transferred through TCP/TLS relay by direction"), + ) + if err != nil { + return err + } + + m.udpActiveSess, err = meter.Int64UpDownCounter( + "proxy.udp.active_sessions", + metric.WithUnit("1"), + metric.WithDescription("Current number of active UDP relay sessions"), + ) + if err != nil { + return err + } + + m.udpSessionsTotal, err = meter.Int64Counter( + "proxy.udp.sessions.total", + metric.WithUnit("1"), + metric.WithDescription("Total UDP relay sessions by result and account"), + ) + if err != nil { + return err + } + + m.udpPacketsTotal, err = meter.Int64Counter( + "proxy.udp.packets.total", + metric.WithUnit("1"), + metric.WithDescription("Total UDP packets relayed by direction"), + ) + if err != nil { + return err + } + + m.udpBytesTotal, err = meter.Int64Counter( + "proxy.udp.bytes.total", + metric.WithUnit("bytes"), + metric.WithDescription("Total bytes transferred through UDP relay by direction"), + ) + return err } type responseInterceptor struct { @@ -120,6 +226,13 @@ func (w *responseInterceptor) Write(b []byte) (int, error) { return size, err } +// Unwrap returns the underlying ResponseWriter so http.ResponseController +// can reach through to the original writer for Hijack/Flush operations. +func (w *responseInterceptor) Unwrap() http.ResponseWriter { + return w.PassthroughWriter +} + +// Middleware wraps an HTTP handler with request metrics. func (m *Metrics) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.requestsTotal.Add(m.ctx, 1) @@ -144,6 +257,7 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } +// RoundTripper wraps an http.RoundTripper with backend duration metrics. func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (*http.Response, error) { start := time.Now() @@ -156,6 +270,7 @@ func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper { }) } +// AddMapping records that a domain mapping was added. func (m *Metrics) AddMapping(mapping proxy.Mapping) { m.mappingsMux.Lock() defer m.mappingsMux.Unlock() @@ -175,13 +290,13 @@ func (m *Metrics) AddMapping(mapping proxy.Mapping) { m.mappingPaths[mapping.Host] = newPathCount } +// RemoveMapping records that a domain mapping was removed. func (m *Metrics) RemoveMapping(mapping proxy.Mapping) { m.mappingsMux.Lock() defer m.mappingsMux.Unlock() oldPathCount, exists := m.mappingPaths[mapping.Host] if !exists { - // Nothing to remove return } @@ -195,3 +310,80 @@ func (m *Metrics) RemoveMapping(mapping proxy.Mapping) { func (m *Metrics) RecordCertificateIssuance(duration time.Duration) { m.certificateIssueDuration.Record(m.ctx, duration.Milliseconds()) } + +// L4ServiceAdded increments the L4 service gauge for the given mode. +func (m *Metrics) L4ServiceAdded(mode types.ServiceMode) { + m.l4Services.Add(m.ctx, 1, metric.WithAttributes(attribute.String("mode", string(mode)))) +} + +// L4ServiceRemoved decrements the L4 service gauge for the given mode. +func (m *Metrics) L4ServiceRemoved(mode types.ServiceMode) { + m.l4Services.Add(m.ctx, -1, metric.WithAttributes(attribute.String("mode", string(mode)))) +} + +// TCPRelayStarted records a new TCP relay connection starting. +func (m *Metrics) TCPRelayStarted(accountID types.AccountID) { + acct := attribute.String("account_id", string(accountID)) + m.tcpActiveConns.Add(m.ctx, 1, metric.WithAttributes(acct)) + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes(acct, attribute.String("result", "success"))) +} + +// TCPRelayEnded records a TCP relay connection ending and accumulates bytes and duration. +func (m *Metrics) TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64) { + acct := attribute.String("account_id", string(accountID)) + m.tcpActiveConns.Add(m.ctx, -1, metric.WithAttributes(acct)) + m.tcpConnDuration.Record(m.ctx, duration.Milliseconds(), metric.WithAttributes(acct)) + m.tcpBytesTotal.Add(m.ctx, srcToDst, metric.WithAttributes(attribute.String("direction", "client_to_backend"))) + m.tcpBytesTotal.Add(m.ctx, dstToSrc, metric.WithAttributes(attribute.String("direction", "backend_to_client"))) +} + +// TCPRelayDialError records a dial failure for a TCP relay. +func (m *Metrics) TCPRelayDialError(accountID types.AccountID) { + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "dial_error"), + )) +} + +// TCPRelayRejected records a rejected TCP relay (semaphore full). +func (m *Metrics) TCPRelayRejected(accountID types.AccountID) { + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "rejected"), + )) +} + +// UDPSessionStarted records a new UDP session starting. +func (m *Metrics) UDPSessionStarted(accountID types.AccountID) { + acct := attribute.String("account_id", string(accountID)) + m.udpActiveSess.Add(m.ctx, 1, metric.WithAttributes(acct)) + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes(acct, attribute.String("result", "success"))) +} + +// UDPSessionEnded records a UDP session ending. +func (m *Metrics) UDPSessionEnded(accountID types.AccountID) { + m.udpActiveSess.Add(m.ctx, -1, metric.WithAttributes(attribute.String("account_id", string(accountID)))) +} + +// UDPSessionDialError records a dial failure for a UDP session. +func (m *Metrics) UDPSessionDialError(accountID types.AccountID) { + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "dial_error"), + )) +} + +// UDPSessionRejected records a rejected UDP session (limit or rate limited). +func (m *Metrics) UDPSessionRejected(accountID types.AccountID) { + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "rejected"), + )) +} + +// UDPPacketRelayed records a packet relayed in the given direction with its size in bytes. +func (m *Metrics) UDPPacketRelayed(direction types.RelayDirection, bytes int) { + dir := attribute.String("direction", string(direction)) + m.udpPacketsTotal.Add(m.ctx, 1, metric.WithAttributes(dir)) + m.udpBytesTotal.Add(m.ctx, int64(bytes), metric.WithAttributes(dir)) +} diff --git a/proxy/internal/netutil/errors.go b/proxy/internal/netutil/errors.go new file mode 100644 index 000000000..ff24e33d4 --- /dev/null +++ b/proxy/internal/netutil/errors.go @@ -0,0 +1,40 @@ +package netutil + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "net" + "syscall" +) + +// ValidatePort converts an int32 proto port to uint16, returning an error +// if the value is out of the valid 1–65535 range. +func ValidatePort(port int32) (uint16, error) { + if port <= 0 || port > math.MaxUint16 { + return 0, fmt.Errorf("invalid port %d: must be 1–65535", port) + } + return uint16(port), nil +} + +// IsExpectedError returns true for errors that are normal during +// connection teardown and should not be logged as warnings. +func IsExpectedError(err error) bool { + return errors.Is(err, net.ErrClosed) || + errors.Is(err, context.Canceled) || + errors.Is(err, io.EOF) || + errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, syscall.EPIPE) || + errors.Is(err, syscall.ECONNABORTED) +} + +// IsTimeout checks whether the error is a network timeout. +func IsTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + return false +} diff --git a/proxy/internal/netutil/errors_test.go b/proxy/internal/netutil/errors_test.go new file mode 100644 index 000000000..7d6be10ff --- /dev/null +++ b/proxy/internal/netutil/errors_test.go @@ -0,0 +1,92 @@ +package netutil + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidatePort(t *testing.T) { + tests := []struct { + name string + port int32 + want uint16 + wantErr bool + }{ + {"valid min", 1, 1, false}, + {"valid mid", 8080, 8080, false}, + {"valid max", 65535, 65535, false}, + {"zero", 0, 0, true}, + {"negative", -1, 0, true}, + {"too large", 65536, 0, true}, + {"way too large", 100000, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ValidatePort(tt.port) + if tt.wantErr { + assert.Error(t, err) + assert.Zero(t, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestIsExpectedError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"net.ErrClosed", net.ErrClosed, true}, + {"context.Canceled", context.Canceled, true}, + {"io.EOF", io.EOF, true}, + {"ECONNRESET", syscall.ECONNRESET, true}, + {"EPIPE", syscall.EPIPE, true}, + {"ECONNABORTED", syscall.ECONNABORTED, true}, + {"wrapped expected", fmt.Errorf("wrap: %w", net.ErrClosed), true}, + {"unexpected EOF", io.ErrUnexpectedEOF, false}, + {"generic error", errors.New("something"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IsExpectedError(tt.err)) + }) + } +} + +type timeoutErr struct{ timeout bool } + +func (e *timeoutErr) Error() string { return "timeout" } +func (e *timeoutErr) Timeout() bool { return e.timeout } +func (e *timeoutErr) Temporary() bool { return false } + +func TestIsTimeout(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"net timeout", &timeoutErr{timeout: true}, true}, + {"net non-timeout", &timeoutErr{timeout: false}, false}, + {"wrapped timeout", fmt.Errorf("wrap: %w", &timeoutErr{timeout: true}), true}, + {"generic error", errors.New("not a timeout"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IsTimeout(tt.err)) + }) + } +} diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index 22ebbf371..4a61f6bcf 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "net/netip" "sync" "github.com/netbirdio/netbird/proxy/internal/types" @@ -47,10 +48,10 @@ func (o ResponseOrigin) String() string { type CapturedData struct { mu sync.RWMutex RequestID string - ServiceId string + ServiceId types.ServiceID AccountId types.AccountID Origin ResponseOrigin - ClientIP string + ClientIP netip.Addr UserID string AuthMethod string } @@ -63,14 +64,14 @@ func (c *CapturedData) GetRequestID() string { } // SetServiceId safely sets the service ID -func (c *CapturedData) SetServiceId(serviceId string) { +func (c *CapturedData) SetServiceId(serviceId types.ServiceID) { c.mu.Lock() defer c.mu.Unlock() c.ServiceId = serviceId } // GetServiceId safely gets the service ID -func (c *CapturedData) GetServiceId() string { +func (c *CapturedData) GetServiceId() types.ServiceID { c.mu.RLock() defer c.mu.RUnlock() return c.ServiceId @@ -105,14 +106,14 @@ func (c *CapturedData) GetOrigin() ResponseOrigin { } // SetClientIP safely sets the resolved client IP. -func (c *CapturedData) SetClientIP(ip string) { +func (c *CapturedData) SetClientIP(ip netip.Addr) { c.mu.Lock() defer c.mu.Unlock() c.ClientIP = ip } // GetClientIP safely gets the resolved client IP. -func (c *CapturedData) GetClientIP() string { +func (c *CapturedData) GetClientIP() netip.Addr { c.mu.RLock() defer c.mu.RUnlock() return c.ClientIP @@ -161,13 +162,13 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData { return data } -func withServiceId(ctx context.Context, serviceId string) context.Context { +func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context { return context.WithValue(ctx, serviceIdKey, serviceId) } -func ServiceIdFromContext(ctx context.Context) string { +func ServiceIdFromContext(ctx context.Context) types.ServiceID { v := ctx.Value(serviceIdKey) - serviceId, ok := v.(string) + serviceId, ok := v.(types.ServiceID) if !ok { return "" } diff --git a/proxy/internal/proxy/proxy_bench_test.go b/proxy/internal/proxy/proxy_bench_test.go index 5af2167e6..b59ef75c0 100644 --- a/proxy/internal/proxy/proxy_bench_test.go +++ b/proxy/internal/proxy/proxy_bench_test.go @@ -25,7 +25,7 @@ func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) { func BenchmarkServeHTTP(b *testing.B) { rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil) rp.AddMapping(proxy.Mapping{ - ID: rand.Text(), + ID: types.ServiceID(rand.Text()), AccountID: types.AccountID(rand.Text()), Host: "app.example.com", Paths: map[string]*proxy.PathTarget{ @@ -66,7 +66,7 @@ func BenchmarkServeHTTPHostCount(b *testing.B) { target = id } rp.AddMapping(proxy.Mapping{ - ID: id, + ID: types.ServiceID(id), AccountID: types.AccountID(rand.Text()), Host: host, Paths: map[string]*proxy.PathTarget{ @@ -118,7 +118,7 @@ func BenchmarkServeHTTPPathCount(b *testing.B) { } } rp.AddMapping(proxy.Mapping{ - ID: rand.Text(), + ID: types.ServiceID(rand.Text()), AccountID: types.AccountID(rand.Text()), Host: "app.example.com", Paths: paths, diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index b0001d5b9..1ee9b2a42 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/internal/roundtrip" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/web" ) @@ -86,9 +87,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx = roundtrip.WithSkipTLSVerify(ctx) } if pt.RequestTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout) - defer cancel() + ctx = types.WithDialTimeout(ctx, pt.RequestTimeout) } rewriteMatchedPath := result.matchedPath @@ -142,9 +141,9 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost r.Out.Header.Set(k, v) } - clientIP := extractClientIP(r.In.RemoteAddr) + clientIP := extractHostIP(r.In.RemoteAddr) - if IsTrustedProxy(clientIP, p.trustedProxies) { + if isTrustedAddr(clientIP, p.trustedProxies) { p.setTrustedForwardingHeaders(r, clientIP) } else { p.setUntrustedForwardingHeaders(r, clientIP) @@ -214,12 +213,14 @@ func normalizeHost(u *url.URL) string { // setTrustedForwardingHeaders appends to the existing forwarding header chain // and preserves upstream-provided headers when the direct connection is from // a trusted proxy. -func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) { +func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) { + ipStr := clientIP.String() + // Append the direct connection IP to the existing X-Forwarded-For chain. if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" { - r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP) + r.Out.Header.Set("X-Forwarded-For", existing+", "+ipStr) } else { - r.Out.Header.Set("X-Forwarded-For", clientIP) + r.Out.Header.Set("X-Forwarded-For", ipStr) } // Preserve upstream X-Real-IP if present; otherwise resolve through the chain. @@ -227,7 +228,7 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli r.Out.Header.Set("X-Real-IP", realIP) } else { resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies) - r.Out.Header.Set("X-Real-IP", resolved) + r.Out.Header.Set("X-Real-IP", resolved.String()) } // Preserve upstream X-Forwarded-Host if present. @@ -257,10 +258,11 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli // sets them fresh based on the direct connection. This is the default // behavior when no trusted proxies are configured or the direct connection // is from an untrusted source. -func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) { +func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) { + ipStr := clientIP.String() proto := auth.ResolveProto(p.forwardedProto, r.In.TLS) - r.Out.Header.Set("X-Forwarded-For", clientIP) - r.Out.Header.Set("X-Real-IP", clientIP) + r.Out.Header.Set("X-Forwarded-For", ipStr) + r.Out.Header.Set("X-Real-IP", ipStr) r.Out.Header.Set("X-Forwarded-Host", r.In.Host) r.Out.Header.Set("X-Forwarded-Proto", proto) r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto)) @@ -288,16 +290,6 @@ func stripSessionTokenQuery(r *httputil.ProxyRequest) { } } -// extractClientIP extracts the IP address from an http.Request.RemoteAddr -// which is always in host:port format. -func extractClientIP(remoteAddr string) string { - ip, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return remoteAddr - } - return ip -} - // extractForwardedPort returns the port from the Host header if present, // otherwise defaults to the standard port for the resolved protocol. func extractForwardedPort(host, resolvedProto string) string { @@ -327,10 +319,12 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { web.ServeErrorPage(w, r, code, title, message, requestID, status) } -// getClientIP retrieves the resolved client IP from context. +// getClientIP retrieves the resolved client IP string from context. func getClientIP(r *http.Request) string { if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil { - return capturedData.GetClientIP() + if ip := capturedData.GetClientIP(); ip.IsValid() { + return ip.String() + } } return "" } diff --git a/proxy/internal/proxy/reverseproxy_test.go b/proxy/internal/proxy/reverseproxy_test.go index be2fb9105..b05ead198 100644 --- a/proxy/internal/proxy/reverseproxy_test.go +++ b/proxy/internal/proxy/reverseproxy_test.go @@ -284,23 +284,23 @@ func TestRewriteFunc_URLRewriting(t *testing.T) { }) } -func TestExtractClientIP(t *testing.T) { +func TestExtractHostIP(t *testing.T) { tests := []struct { name string remoteAddr string - expected string + expected netip.Addr }{ - {"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"}, - {"IPv6 with port", "[::1]:12345", "::1"}, - {"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"}, - {"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"}, - {"IPv6 without brackets fallback", "::1", "::1"}, - {"empty string fallback", "", ""}, - {"public IP", "203.0.113.50:9999", "203.0.113.50"}, + {"IPv4 with port", "192.168.1.1:12345", netip.MustParseAddr("192.168.1.1")}, + {"IPv6 with port", "[::1]:12345", netip.MustParseAddr("::1")}, + {"IPv6 full with port", "[2001:db8::1]:443", netip.MustParseAddr("2001:db8::1")}, + {"IPv4 without port fallback", "192.168.1.1", netip.MustParseAddr("192.168.1.1")}, + {"IPv6 without brackets fallback", "::1", netip.MustParseAddr("::1")}, + {"empty string fallback", "", netip.Addr{}}, + {"public IP", "203.0.113.50:9999", netip.MustParseAddr("203.0.113.50")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr)) + assert.Equal(t, tt.expected, extractHostIP(tt.remoteAddr)) }) } } diff --git a/proxy/internal/proxy/servicemapping.go b/proxy/internal/proxy/servicemapping.go index 58b92ff9e..1513fbe45 100644 --- a/proxy/internal/proxy/servicemapping.go +++ b/proxy/internal/proxy/servicemapping.go @@ -30,8 +30,9 @@ type PathTarget struct { CustomHeaders map[string]string } +// Mapping describes how a domain is routed by the HTTP reverse proxy. type Mapping struct { - ID string + ID types.ServiceID AccountID types.AccountID Host string Paths map[string]*PathTarget @@ -42,7 +43,7 @@ type Mapping struct { type targetResult struct { target *PathTarget matchedPath string - serviceID string + serviceID types.ServiceID accountID types.AccountID passHostHeader bool rewriteRedirects bool @@ -101,8 +102,13 @@ func (p *ReverseProxy) AddMapping(m Mapping) { p.mappings[m.Host] = m } -func (p *ReverseProxy) RemoveMapping(m Mapping) { +// RemoveMapping removes the mapping for the given host and reports whether it existed. +func (p *ReverseProxy) RemoveMapping(m Mapping) bool { p.mappingsMux.Lock() defer p.mappingsMux.Unlock() + if _, ok := p.mappings[m.Host]; !ok { + return false + } delete(p.mappings, m.Host) + return true } diff --git a/proxy/internal/proxy/trustedproxy.go b/proxy/internal/proxy/trustedproxy.go index ad9a5b6c0..0fe693f90 100644 --- a/proxy/internal/proxy/trustedproxy.go +++ b/proxy/internal/proxy/trustedproxy.go @@ -7,21 +7,11 @@ import ( // IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes. func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool { - if len(trusted) == 0 { - return false - } - addr, err := netip.ParseAddr(ipStr) - if err != nil { + if err != nil || len(trusted) == 0 { return false } - - for _, prefix := range trusted { - if prefix.Contains(addr) { - return true - } - } - return false + return isTrustedAddr(addr.Unmap(), trusted) } // ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list. @@ -30,10 +20,10 @@ func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool { // // If the trusted list is empty or remoteAddr is not trusted, it returns the // remoteAddr IP directly (ignoring any forwarding headers). -func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string { - remoteIP := extractClientIP(remoteAddr) +func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) netip.Addr { + remoteIP := extractHostIP(remoteAddr) - if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) { + if len(trusted) == 0 || !isTrustedAddr(remoteIP, trusted) { return remoteIP } @@ -47,14 +37,45 @@ func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string { if ip == "" { continue } - if !IsTrustedProxy(ip, trusted) { - return ip + addr, err := netip.ParseAddr(ip) + if err != nil { + continue + } + addr = addr.Unmap() + if !isTrustedAddr(addr, trusted) { + return addr } } // All IPs in XFF are trusted; return the leftmost as best guess. if first := strings.TrimSpace(parts[0]); first != "" { - return first + if addr, err := netip.ParseAddr(first); err == nil { + return addr.Unmap() + } } return remoteIP } + +// extractHostIP parses the IP from a host:port string and returns it unmapped. +func extractHostIP(hostPort string) netip.Addr { + if ap, err := netip.ParseAddrPort(hostPort); err == nil { + return ap.Addr().Unmap() + } + if addr, err := netip.ParseAddr(hostPort); err == nil { + return addr.Unmap() + } + return netip.Addr{} +} + +// isTrustedAddr checks if the given address falls within any of the trusted prefixes. +func isTrustedAddr(addr netip.Addr, trusted []netip.Prefix) bool { + if !addr.IsValid() { + return false + } + for _, prefix := range trusted { + if prefix.Contains(addr) { + return true + } + } + return false +} diff --git a/proxy/internal/proxy/trustedproxy_test.go b/proxy/internal/proxy/trustedproxy_test.go index 827b7babf..35ed1f5c2 100644 --- a/proxy/internal/proxy/trustedproxy_test.go +++ b/proxy/internal/proxy/trustedproxy_test.go @@ -48,77 +48,77 @@ func TestResolveClientIP(t *testing.T) { remoteAddr string xff string trusted []netip.Prefix - want string + want netip.Addr }{ { name: "empty trusted list returns RemoteAddr", remoteAddr: "203.0.113.50:9999", xff: "1.2.3.4", trusted: nil, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "untrusted RemoteAddr ignores XFF", remoteAddr: "203.0.113.50:9999", xff: "1.2.3.4, 10.0.0.1", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr with single client in XFF", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr walks past trusted entries in XFF", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50, 10.0.0.2, 172.16.0.5", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr", remoteAddr: "10.0.0.1:5000", xff: "", trusted: trusted, - want: "10.0.0.1", + want: netip.MustParseAddr("10.0.0.1"), }, { name: "all XFF IPs trusted returns leftmost", remoteAddr: "10.0.0.1:5000", xff: "10.0.0.2, 172.16.0.1, 10.0.0.3", trusted: trusted, - want: "10.0.0.2", + want: netip.MustParseAddr("10.0.0.2"), }, { name: "XFF with whitespace", remoteAddr: "10.0.0.1:5000", xff: " 203.0.113.50 , 10.0.0.2 ", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "XFF with empty segments", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50,,10.0.0.2", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "multi-hop with mixed trust", remoteAddr: "10.0.0.1:5000", xff: "8.8.8.8, 203.0.113.50, 172.16.0.1", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "RemoteAddr without port", remoteAddr: "10.0.0.1", xff: "203.0.113.50", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, } for _, tt := range tests { diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index 57770f4a5..e38e3dc4e 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "sync" "time" @@ -14,11 +15,12 @@ import ( "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/embed" nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -26,7 +28,22 @@ import ( const deviceNamePrefix = "ingress-proxy-" // backendKey identifies a backend by its host:port from the target URL. -type backendKey = string +type backendKey string + +// ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service) +// that holds a reference to an embedded NetBird client. Callers should use the +// DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions. +type ServiceKey string + +// DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service. +func DomainServiceKey(domain string) ServiceKey { + return ServiceKey("domain:" + domain) +} + +// L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP). +func L4ServiceKey(id types.ServiceID) ServiceKey { + return ServiceKey("l4:" + id) +} var ( // ErrNoAccountID is returned when a request context is missing the account ID. @@ -39,24 +56,24 @@ var ( ErrTooManyInflight = errors.New("too many in-flight requests") ) -// domainInfo holds metadata about a registered domain. -type domainInfo struct { - serviceID string +// serviceInfo holds metadata about a registered service. +type serviceInfo struct { + serviceID types.ServiceID } -type domainNotification struct { - domain domain.Domain - serviceID string +type serviceNotification struct { + key ServiceKey + serviceID types.ServiceID } -// clientEntry holds an embedded NetBird client and tracks which domains use it. +// clientEntry holds an embedded NetBird client and tracks which services use it. type clientEntry struct { client *embed.Client transport *http.Transport // insecureTransport is a clone of transport with TLS verification disabled, // used when per-target skip_tls_verify is set. insecureTransport *http.Transport - domains map[domain.Domain]domainInfo + services map[ServiceKey]serviceInfo createdAt time.Time started bool // Per-backend in-flight limiting keyed by target host:port. @@ -93,12 +110,12 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo // ClientConfig holds configuration for the embedded NetBird client. type ClientConfig struct { MgmtAddr string - WGPort int + WGPort uint16 PreSharedKey string } type statusNotifier interface { - NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error + NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error } type managementClient interface { @@ -107,7 +124,7 @@ type managementClient interface { // NetBird provides an http.RoundTripper implementation // backed by underlying NetBird connections. -// Clients are keyed by AccountID, allowing multiple domains to share the same connection. +// Clients are keyed by AccountID, allowing multiple services to share the same connection. type NetBird struct { proxyID string proxyAddr string @@ -124,11 +141,11 @@ type NetBird struct { // ClientDebugInfo contains debug information about a client. type ClientDebugInfo struct { - AccountID types.AccountID - DomainCount int - Domains domain.List - HasClient bool - CreatedAt time.Time + AccountID types.AccountID + ServiceCount int + ServiceKeys []string + HasClient bool + CreatedAt time.Time } // accountIDContextKey is the context key for storing the account ID. @@ -137,37 +154,37 @@ type accountIDContextKey struct{} // skipTLSVerifyContextKey is the context key for requesting insecure TLS. type skipTLSVerifyContextKey struct{} -// AddPeer registers a domain for an account. If the account doesn't have a client yet, +// AddPeer registers a service for an account. If the account doesn't have a client yet, // one is created by authenticating with the management server using the provided token. -// Multiple domains can share the same client. -func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error { +// Multiple services can share the same client. +func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error { + si := serviceInfo{serviceID: serviceID} + n.clientsMux.Lock() entry, exists := n.clients[accountID] if exists { - // Client already exists for this account, just register the domain - entry.domains[d] = domainInfo{serviceID: serviceID} + entry.services[key] = si started := entry.started n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).Debug("registered domain with existing client") + "account_id": accountID, + "service_key": key, + }).Debug("registered service with existing client") - // If client is already started, notify this domain as connected immediately if started && n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil { + if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, + "account_id": accountID, + "service_key": key, }).WithError(err).Warn("failed to notify status for existing client") } } return nil } - entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID) + entry, err := n.createClientEntry(ctx, accountID, key, authToken, si) if err != nil { n.clientsMux.Unlock() return err @@ -177,8 +194,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, + "account_id": accountID, + "service_key": key, }).Info("created new client for account") // Attempt to start the client in the background; if this fails we will @@ -190,7 +207,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma // createClientEntry generates a WireGuard keypair, authenticates with management, // and creates an embedded NetBird client. Must be called with clientsMux held. -func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) { +func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) { + serviceID := si.serviceID n.logger.WithFields(log.Fields{ "account_id": accountID, "service_id": serviceID, @@ -209,7 +227,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account }).Debug("authenticating new proxy peer with management") resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{ - ServiceId: serviceID, + ServiceId: string(serviceID), AccountId: string(accountID), Token: authToken, WireguardPublicKey: publicKey.String(), @@ -240,13 +258,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // Create embedded NetBird client with the generated private key. // The peer has already been created via CreateProxyPeer RPC with the public key. + wgPort := int(n.clientCfg.WGPort) client, err := embed.New(embed.Options{ DeviceName: deviceNamePrefix + n.proxyID, ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), LogLevel: log.WarnLevel.String(), BlockInbound: true, - WireguardPort: &n.clientCfg.WGPort, + WireguardPort: &wgPort, PreSharedKey: n.clientCfg.PreSharedKey, }) if err != nil { @@ -257,7 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // the client's HTTPClient to avoid issues with request validation that do // not work with reverse proxied requests. transport := &http.Transport{ - DialContext: client.DialContext, + DialContext: dialWithTimeout(client.DialContext), ForceAttemptHTTP2: true, MaxIdleConns: n.transportCfg.maxIdleConns, MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost, @@ -276,7 +295,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account return &clientEntry{ client: client, - domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}}, + services: map[ServiceKey]serviceInfo{key: si}, transport: transport, insecureTransport: insecureTransport, createdAt: time.Now(), @@ -286,7 +305,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account }, nil } -// runClientStartup starts the client and notifies registered domains on success. +// runClientStartup starts the client and notifies registered services on success. func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) { startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -300,16 +319,16 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI return } - // Mark client as started and collect domains to notify outside the lock. + // Mark client as started and collect services to notify outside the lock. n.clientsMux.Lock() entry, exists := n.clients[accountID] if exists { entry.started = true } - var domainsToNotify []domainNotification + var toNotify []serviceNotification if exists { - for dom, info := range entry.domains { - domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID}) + for key, info := range entry.services { + toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID}) } } n.clientsMux.Unlock() @@ -317,24 +336,24 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI if n.statusNotifier == nil { return } - for _, dn := range domainsToNotify { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil { + for _, sn := range toNotify { + if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": dn.domain, + "account_id": accountID, + "service_key": sn.key, }).WithError(err).Warn("failed to notify tunnel connection status") } else { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": dn.domain, + "account_id": accountID, + "service_key": sn.key, }).Info("notified management about tunnel connection") } } } -// RemovePeer unregisters a domain from an account. The client is only stopped -// when no domains are using it anymore. -func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error { +// RemovePeer unregisters a service from an account. The client is only stopped +// when no services are using it anymore. +func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error { n.clientsMux.Lock() entry, exists := n.clients[accountID] @@ -344,74 +363,65 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d return nil } - // Get domain info before deleting - domInfo, domainExists := entry.domains[d] - if !domainExists { + si, svcExists := entry.services[key] + if !svcExists { n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).Debug("remove peer: domain not registered") + "account_id": accountID, + "service_key": key, + }).Debug("remove peer: service not registered") return nil } - delete(entry.domains, d) - - // If there are still domains using this client, keep it running - if len(entry.domains) > 0 { - n.clientsMux.Unlock() + delete(entry.services, key) + stopClient := len(entry.services) == 0 + var client *embed.Client + var transport, insecureTransport *http.Transport + if stopClient { + n.logger.WithField("account_id", accountID).Info("stopping client, no more services") + client = entry.client + transport = entry.transport + insecureTransport = entry.insecureTransport + delete(n.clients, accountID) + } else { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - "remaining_domains": len(entry.domains), - }).Debug("unregistered domain, client still in use") - - // Notify this domain as disconnected - if n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).WithError(err).Warn("failed to notify tunnel disconnection status") - } - } - return nil + "account_id": accountID, + "service_key": key, + "remaining_services": len(entry.services), + }).Debug("unregistered service, client still in use") } - - // No more domains using this client, stop it - n.logger.WithFields(log.Fields{ - "account_id": accountID, - }).Info("stopping client, no more domains") - - client := entry.client - transport := entry.transport - insecureTransport := entry.insecureTransport - delete(n.clients, accountID) n.clientsMux.Unlock() - // Notify disconnection before stopping - if n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).WithError(err).Warn("failed to notify tunnel disconnection status") + n.notifyDisconnect(ctx, accountID, key, si.serviceID) + + if stopClient { + transport.CloseIdleConnections() + insecureTransport.CloseIdleConnections() + if err := client.Stop(ctx); err != nil { + n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client") } } - transport.CloseIdleConnections() - insecureTransport.CloseIdleConnections() - - if err := client.Stop(ctx); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - }).WithError(err).Warn("failed to stop netbird client") - } - return nil } +func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) { + if n.statusNotifier == nil { + return + } + if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil { + if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound { + n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification") + } else { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "service_key": key, + }).WithError(err).Warn("failed to notify tunnel disconnection status") + } + } +} + // RoundTrip implements http.RoundTripper. It looks up the client for the account // specified in the request context and uses it to dial the backend. func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { @@ -435,7 +445,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { } n.clientsMux.RUnlock() - release, ok := entry.acquireInflight(req.URL.Host) + release, ok := entry.acquireInflight(backendKey(req.URL.Host)) defer release() if !ok { return nil, ErrTooManyInflight @@ -496,16 +506,16 @@ func (n *NetBird) HasClient(accountID types.AccountID) bool { return exists } -// DomainCount returns the number of domains registered for the given account. +// ServiceCount returns the number of services registered for the given account. // Returns 0 if the account has no client. -func (n *NetBird) DomainCount(accountID types.AccountID) int { +func (n *NetBird) ServiceCount(accountID types.AccountID) int { n.clientsMux.RLock() defer n.clientsMux.RUnlock() entry, exists := n.clients[accountID] if !exists { return 0 } - return len(entry.domains) + return len(entry.services) } // ClientCount returns the total number of active clients. @@ -533,16 +543,16 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo { result := make(map[types.AccountID]ClientDebugInfo) for accountID, entry := range n.clients { - domains := make(domain.List, 0, len(entry.domains)) - for d := range entry.domains { - domains = append(domains, d) + keys := make([]string, 0, len(entry.services)) + for k := range entry.services { + keys = append(keys, string(k)) } result[accountID] = ClientDebugInfo{ - AccountID: accountID, - DomainCount: len(entry.domains), - Domains: domains, - HasClient: entry.client != nil, - CreatedAt: entry.createdAt, + AccountID: accountID, + ServiceCount: len(entry.services), + ServiceKeys: keys, + HasClient: entry.client != nil, + CreatedAt: entry.createdAt, } } return result @@ -581,6 +591,20 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L } } +// dialWithTimeout wraps a DialContext function so that any dial timeout +// stored in the context (via types.WithDialTimeout) is applied only to +// the connection establishment phase, not the full request lifetime. +func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + if d, ok := types.DialTimeoutFromContext(ctx); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, d) + defer cancel() + } + return dial(ctx, network, addr) + } +} + // WithAccountID adds the account ID to the context. func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context { return context.WithValue(ctx, accountIDContextKey{}, accountID) diff --git a/proxy/internal/roundtrip/netbird_bench_test.go b/proxy/internal/roundtrip/netbird_bench_test.go index e89213c33..330ea0332 100644 --- a/proxy/internal/roundtrip/netbird_bench_test.go +++ b/proxy/internal/roundtrip/netbird_bench_test.go @@ -1,6 +1,7 @@ package roundtrip import ( + "context" "crypto/rand" "math/big" "sync" @@ -8,7 +9,6 @@ import ( "time" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" ) // Simple benchmark for comparison with AddPeer contention. @@ -29,9 +29,9 @@ func BenchmarkHasClient(b *testing.B) { target = id } nb.clients[id] = &clientEntry{ - domains: map[domain.Domain]domainInfo{ - domain.Domain(rand.Text()): { - serviceID: rand.Text(), + services: map[ServiceKey]serviceInfo{ + ServiceKey(rand.Text()): { + serviceID: types.ServiceID(rand.Text()), }, }, createdAt: time.Now(), @@ -70,9 +70,9 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { target = id } nb.clients[id] = &clientEntry{ - domains: map[domain.Domain]domainInfo{ - domain.Domain(rand.Text()): { - serviceID: rand.Text(), + services: map[ServiceKey]serviceInfo{ + ServiceKey(rand.Text()): { + serviceID: types.ServiceID(rand.Text()), }, }, createdAt: time.Now(), @@ -81,19 +81,22 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { } // Launch workers that continuously call AddPeer with new random accountIDs. + ctx, cancel := context.WithCancel(b.Context()) var wg sync.WaitGroup for range addPeerWorkers { - wg.Go(func() { - for { - if err := nb.AddPeer(b.Context(), + wg.Add(1) + go func() { + defer wg.Done() + for ctx.Err() == nil { + if err := nb.AddPeer(ctx, types.AccountID(rand.Text()), - domain.Domain(rand.Text()), + ServiceKey(rand.Text()), rand.Text(), - rand.Text()); err != nil { - b.Log(err) + types.ServiceID(rand.Text())); err != nil { + return } } - }) + }() } // Benchmark calling HasClient during AddPeer contention. @@ -104,4 +107,6 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { } }) b.StopTimer() + cancel() + wg.Wait() } diff --git a/proxy/internal/roundtrip/netbird_test.go b/proxy/internal/roundtrip/netbird_test.go index 0a742c2fa..5444f6c11 100644 --- a/proxy/internal/roundtrip/netbird_test.go +++ b/proxy/internal/roundtrip/netbird_test.go @@ -11,7 +11,6 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -27,16 +26,15 @@ type mockStatusNotifier struct { } type statusCall struct { - accountID string - serviceID string - domain string + accountID types.AccountID + serviceID types.ServiceID connected bool } -func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error { +func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error { m.mu.Lock() defer m.mu.Unlock() - m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected}) + m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected}) return nil } @@ -62,36 +60,34 @@ func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) { // Initially no client exists. assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer") - assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0") + assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0") - // Add first domain - this should create a new client. - // Note: This will fail to actually connect since we use an invalid URL, - // but the client entry should still be created. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add first service - this should create a new client. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.True(t, nb.HasClient(accountID), "should have client after AddPeer") - assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1") + assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1") } func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add first domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add first service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - assert.Equal(t, 1, nb.DomainCount(accountID)) + assert.Equal(t, 1, nb.ServiceCount(accountID)) - // Add second domain for the same account - should reuse existing client. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2") + // Add second service for the same account - should reuse existing client. + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2")) require.NoError(t, err) - assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain") + assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2 after adding second service") - // Add third domain. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3") + // Add third service. + err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3")) require.NoError(t, err) - assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain") + assert.Equal(t, 3, nb.ServiceCount(accountID), "service count should be 3 after adding third service") // Still only one client. assert.True(t, nb.HasClient(accountID)) @@ -102,64 +98,62 @@ func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) { account1 := types.AccountID("account-1") account2 := types.AccountID("account-2") - // Add domain for account 1. - err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add service for account 1. + err := nb.AddPeer(context.Background(), account1, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - // Add domain for account 2. - err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2") + // Add service for account 2. + err = nb.AddPeer(context.Background(), account2, "domain2.test", "setup-key-2", types.ServiceID("proxy-2")) require.NoError(t, err) // Both accounts should have their own clients. assert.True(t, nb.HasClient(account1), "account1 should have client") assert.True(t, nb.HasClient(account2), "account2 should have client") - assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1") - assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1") + assert.Equal(t, 1, nb.ServiceCount(account1), "account1 service count should be 1") + assert.Equal(t, 1, nb.ServiceCount(account2), "account2 service count should be 1") } -func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) { +func TestNetBird_RemovePeer_KeepsClientWhenServicesRemain(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add multiple domains. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add multiple services. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2") + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3") + err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3")) require.NoError(t, err) - assert.Equal(t, 3, nb.DomainCount(accountID)) + assert.Equal(t, 3, nb.ServiceCount(accountID)) - // Remove one domain - client should remain. + // Remove one service - client should remain. err = nb.RemovePeer(context.Background(), accountID, "domain1.test") require.NoError(t, err) - assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain") - assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2") + assert.True(t, nb.HasClient(accountID), "client should remain after removing one service") + assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2") - // Remove another domain - client should still remain. + // Remove another service - client should still remain. err = nb.RemovePeer(context.Background(), accountID, "domain2.test") require.NoError(t, err) - assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain") - assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1") + assert.True(t, nb.HasClient(accountID), "client should remain after removing second service") + assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1") } -func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) { +func TestNetBird_RemovePeer_RemovesClientWhenLastServiceRemoved(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add single domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add single service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.True(t, nb.HasClient(accountID)) - // Remove the only domain - client should be removed. - // Note: Stop() may fail since the client never actually connected, - // but the entry should still be removed from the map. + // Remove the only service - client should be removed. _ = nb.RemovePeer(context.Background(), accountID, "domain1.test") - // After removing all domains, client should be gone. - assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain") - assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0") + // After removing all services, client should be gone. + assert.False(t, nb.HasClient(accountID), "client should be removed after removing last service") + assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0") } func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) { @@ -171,21 +165,21 @@ func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) { assert.NoError(t, err, "removing from non-existent account should not error") } -func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) { +func TestNetBird_RemovePeer_NonExistentServiceIsNoop(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add one domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add one service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - // Remove non-existent domain - should not affect existing domain. - err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test")) + // Remove non-existent service - should not affect existing service. + err = nb.RemovePeer(context.Background(), accountID, "nonexistent.test") require.NoError(t, err) - // Original domain should still be registered. + // Original service should still be registered. assert.True(t, nb.HasClient(accountID)) - assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain") + assert.Equal(t, 1, nb.ServiceCount(accountID), "original service should remain") } func TestWithAccountID_AndAccountIDFromContext(t *testing.T) { @@ -216,19 +210,17 @@ func TestNetBird_StopAll_StopsAllClients(t *testing.T) { account2 := types.AccountID("account-2") account3 := types.AccountID("account-3") - // Add domains for multiple accounts. - err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1") + // Add services for multiple accounts. + err := nb.AddPeer(context.Background(), account1, "domain1.test", "key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2") + err = nb.AddPeer(context.Background(), account2, "domain2.test", "key-2", types.ServiceID("proxy-2")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3") + err = nb.AddPeer(context.Background(), account3, "domain3.test", "key-3", types.ServiceID("proxy-3")) require.NoError(t, err) assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients") // Stop all clients. - // Note: StopAll may return errors since clients never actually connected, - // but the clients should still be removed from the map. _ = nb.StopAll(context.Background()) assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll") @@ -243,18 +235,18 @@ func TestNetBird_ClientCount(t *testing.T) { assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients") // Add clients for different accounts. - err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1") + err := nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1.test", "key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.Equal(t, 1, nb.ClientCount()) - err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2") + err = nb.AddPeer(context.Background(), types.AccountID("account-2"), "domain2.test", "key-2", types.ServiceID("proxy-2")) require.NoError(t, err) assert.Equal(t, 2, nb.ClientCount()) - // Adding domain to existing account should not increase count. - err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b") + // Adding service to existing account should not increase count. + err = nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1b.test", "key-1", types.ServiceID("proxy-1b")) require.NoError(t, err) - assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count") + assert.Equal(t, 2, nb.ClientCount(), "adding service to existing account should not increase client count") } func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) { @@ -293,8 +285,8 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") - // Add first domain — creates a new client entry. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") + // Add first service — creates a new client entry. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1")) require.NoError(t, err) // Manually mark client as started to simulate background startup completing. @@ -302,15 +294,14 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { nb.clients[accountID].started = true nb.clientsMux.Unlock() - // Add second domain — should notify immediately since client is already started. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2") + // Add second service — should notify immediately since client is already started. + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2")) require.NoError(t, err) calls := notifier.calls() require.Len(t, calls, 1) - assert.Equal(t, string(accountID), calls[0].accountID) - assert.Equal(t, "svc-2", calls[0].serviceID) - assert.Equal(t, "domain2.test", calls[0].domain) + assert.Equal(t, accountID, calls[0].accountID) + assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID) assert.True(t, calls[0].connected) } @@ -323,18 +314,18 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) { }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2") + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2")) require.NoError(t, err) - // Remove one domain — client stays, but disconnection notification fires. + // Remove one service — client stays, but disconnection notification fires. err = nb.RemovePeer(context.Background(), accountID, "domain1.test") require.NoError(t, err) assert.True(t, nb.HasClient(accountID)) calls := notifier.calls() require.Len(t, calls, 1) - assert.Equal(t, "domain1.test", calls[0].domain) + assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID) assert.False(t, calls[0].connected) } diff --git a/proxy/internal/tcp/bench_test.go b/proxy/internal/tcp/bench_test.go new file mode 100644 index 000000000..049f8395d --- /dev/null +++ b/proxy/internal/tcp/bench_test.go @@ -0,0 +1,133 @@ +package tcp + +import ( + "bytes" + "crypto/tls" + "io" + "net" + "testing" +) + +// BenchmarkPeekClientHello_TLS measures the overhead of peeking at a real +// TLS ClientHello and extracting the SNI. This is the per-connection cost +// added to every TLS connection on the main listener. +func BenchmarkPeekClientHello_TLS(b *testing.B) { + // Pre-generate a ClientHello by capturing what crypto/tls sends. + clientConn, serverConn := net.Pipe() + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + var hello []byte + buf := make([]byte, 16384) + n, _ := serverConn.Read(buf) + hello = make([]byte, n) + copy(hello, buf[:n]) + clientConn.Close() + serverConn.Close() + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(hello) + conn := &readerConn{Reader: r} + sni, wrapped, err := PeekClientHello(conn) + if err != nil { + b.Fatal(err) + } + if sni != "app.example.com" { + b.Fatalf("unexpected SNI: %q", sni) + } + // Simulate draining the peeked bytes (what the HTTP server would do). + _, _ = io.Copy(io.Discard, wrapped) + } +} + +// BenchmarkPeekClientHello_NonTLS measures peek overhead for non-TLS +// connections that hit the fast non-handshake exit path. +func BenchmarkPeekClientHello_NonTLS(b *testing.B) { + httpReq := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(httpReq) + conn := &readerConn{Reader: r} + _, wrapped, err := PeekClientHello(conn) + if err != nil { + b.Fatal(err) + } + _, _ = io.Copy(io.Discard, wrapped) + } +} + +// BenchmarkPeekedConn_Read measures the read overhead of the peekedConn +// wrapper compared to a plain connection read. The peeked bytes use +// io.MultiReader which adds one indirection per Read call. +func BenchmarkPeekedConn_Read(b *testing.B) { + data := make([]byte, 4096) + peeked := make([]byte, 512) + buf := make([]byte, 1024) + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(data) + conn := &readerConn{Reader: r} + pc := newPeekedConn(conn, peeked) + for { + _, err := pc.Read(buf) + if err != nil { + break + } + } + } +} + +// BenchmarkExtractSNI measures just the in-memory SNI parsing cost, +// excluding I/O. +func BenchmarkExtractSNI(b *testing.B) { + clientConn, serverConn := net.Pipe() + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + buf := make([]byte, 16384) + n, _ := serverConn.Read(buf) + payload := make([]byte, n-tlsRecordHeaderLen) + copy(payload, buf[tlsRecordHeaderLen:n]) + clientConn.Close() + serverConn.Close() + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + sni := extractSNI(payload) + if sni != "app.example.com" { + b.Fatalf("unexpected SNI: %q", sni) + } + } +} + +// readerConn wraps an io.Reader as a net.Conn for benchmarking. +// Only Read is functional; all other methods are no-ops. +type readerConn struct { + io.Reader + net.Conn +} + +func (c *readerConn) Read(b []byte) (int, error) { + return c.Reader.Read(b) +} diff --git a/proxy/internal/tcp/chanlistener.go b/proxy/internal/tcp/chanlistener.go new file mode 100644 index 000000000..ee64bc0a2 --- /dev/null +++ b/proxy/internal/tcp/chanlistener.go @@ -0,0 +1,76 @@ +package tcp + +import ( + "net" + "sync" +) + +// chanListener implements net.Listener by reading connections from a channel. +// It allows the SNI router to feed HTTP connections to http.Server.ServeTLS. +type chanListener struct { + ch chan net.Conn + addr net.Addr + once sync.Once + closed chan struct{} +} + +func newChanListener(ch chan net.Conn, addr net.Addr) *chanListener { + return &chanListener{ + ch: ch, + addr: addr, + closed: make(chan struct{}), + } +} + +// Accept waits for and returns the next connection from the channel. +func (l *chanListener) Accept() (net.Conn, error) { + for { + select { + case conn, ok := <-l.ch: + if !ok { + return nil, net.ErrClosed + } + return conn, nil + case <-l.closed: + // Drain buffered connections before returning. + for { + select { + case conn, ok := <-l.ch: + if !ok { + return nil, net.ErrClosed + } + _ = conn.Close() + default: + return nil, net.ErrClosed + } + } + } + } +} + +// Close signals the listener to stop accepting connections and drains +// any buffered connections that have not yet been accepted. +func (l *chanListener) Close() error { + l.once.Do(func() { + close(l.closed) + for { + select { + case conn, ok := <-l.ch: + if !ok { + return + } + _ = conn.Close() + default: + return + } + } + }) + return nil +} + +// Addr returns the listener's network address. +func (l *chanListener) Addr() net.Addr { + return l.addr +} + +var _ net.Listener = (*chanListener)(nil) diff --git a/proxy/internal/tcp/peekedconn.go b/proxy/internal/tcp/peekedconn.go new file mode 100644 index 000000000..26f3e5c7c --- /dev/null +++ b/proxy/internal/tcp/peekedconn.go @@ -0,0 +1,39 @@ +package tcp + +import ( + "bytes" + "io" + "net" +) + +// peekedConn wraps a net.Conn and prepends previously peeked bytes +// so that readers see the full original stream transparently. +type peekedConn struct { + net.Conn + reader io.Reader +} + +func newPeekedConn(conn net.Conn, peeked []byte) *peekedConn { + return &peekedConn{ + Conn: conn, + reader: io.MultiReader(bytes.NewReader(peeked), conn), + } +} + +// Read replays the peeked bytes first, then reads from the underlying conn. +func (c *peekedConn) Read(b []byte) (int, error) { + return c.reader.Read(b) +} + +// CloseWrite delegates to the underlying connection if it supports +// half-close (e.g. *net.TCPConn). Without this, embedding net.Conn +// as an interface hides the concrete type's CloseWrite method, making +// half-close a silent no-op for all SNI-routed connections. +func (c *peekedConn) CloseWrite() error { + if hc, ok := c.Conn.(halfCloser); ok { + return hc.CloseWrite() + } + return nil +} + +var _ halfCloser = (*peekedConn)(nil) diff --git a/proxy/internal/tcp/proxyprotocol.go b/proxy/internal/tcp/proxyprotocol.go new file mode 100644 index 000000000..699b75a5d --- /dev/null +++ b/proxy/internal/tcp/proxyprotocol.go @@ -0,0 +1,29 @@ +package tcp + +import ( + "fmt" + "net" + + "github.com/pires/go-proxyproto" +) + +// writeProxyProtoV2 sends a PROXY protocol v2 header to the backend connection, +// conveying the real client address. +func writeProxyProtoV2(client, backend net.Conn) error { + tp := proxyproto.TCPv4 + if addr, ok := client.RemoteAddr().(*net.TCPAddr); ok && addr.IP.To4() == nil { + tp = proxyproto.TCPv6 + } + + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: tp, + SourceAddr: client.RemoteAddr(), + DestinationAddr: client.LocalAddr(), + } + if _, err := header.WriteTo(backend); err != nil { + return fmt.Errorf("write PROXY protocol v2 header: %w", err) + } + return nil +} diff --git a/proxy/internal/tcp/proxyprotocol_test.go b/proxy/internal/tcp/proxyprotocol_test.go new file mode 100644 index 000000000..f8c48b2ab --- /dev/null +++ b/proxy/internal/tcp/proxyprotocol_test.go @@ -0,0 +1,128 @@ +package tcp + +import ( + "bufio" + "net" + "testing" + + "github.com/pires/go-proxyproto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWriteProxyProtoV2_IPv4(t *testing.T) { + // Set up a real TCP listener and dial to get connections with real addresses. + ln, err := net.Listen("tcp4", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + var serverConn net.Conn + accepted := make(chan struct{}) + go func() { + var err error + serverConn, err = ln.Accept() + if err != nil { + t.Error("accept failed:", err) + } + close(accepted) + }() + + clientConn, err := net.Dial("tcp4", ln.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + <-accepted + defer serverConn.Close() + + // Use a pipe as the backend: write the header to one end, read from the other. + backendRead, backendWrite := net.Pipe() + defer backendRead.Close() + defer backendWrite.Close() + + // serverConn is the "client" arg: RemoteAddr is the source, LocalAddr is the destination. + writeDone := make(chan error, 1) + go func() { + writeDone <- writeProxyProtoV2(serverConn, backendWrite) + }() + + // Read the PROXY protocol header from the backend read side. + header, err := proxyproto.Read(bufio.NewReader(backendRead)) + require.NoError(t, err) + require.NotNil(t, header, "should have received a proxy protocol header") + + writeErr := <-writeDone + require.NoError(t, writeErr) + + assert.Equal(t, byte(2), header.Version, "version should be 2") + assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY") + assert.Equal(t, proxyproto.TCPv4, header.TransportProtocol, "transport should be TCPv4") + + // serverConn.RemoteAddr() is the client's address (source in the header). + expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr) + actualSrc := header.SourceAddr.(*net.TCPAddr) + assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr") + assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr") + + // serverConn.LocalAddr() is the server's address (destination in the header). + expectedDst := serverConn.LocalAddr().(*net.TCPAddr) + actualDst := header.DestinationAddr.(*net.TCPAddr) + assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr") + assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr") +} + +func TestWriteProxyProtoV2_IPv6(t *testing.T) { + // Set up a real TCP6 listener on loopback. + ln, err := net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Skip("IPv6 not available:", err) + } + defer ln.Close() + + var serverConn net.Conn + accepted := make(chan struct{}) + go func() { + var err error + serverConn, err = ln.Accept() + if err != nil { + t.Error("accept failed:", err) + } + close(accepted) + }() + + clientConn, err := net.Dial("tcp6", ln.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + <-accepted + defer serverConn.Close() + + backendRead, backendWrite := net.Pipe() + defer backendRead.Close() + defer backendWrite.Close() + + writeDone := make(chan error, 1) + go func() { + writeDone <- writeProxyProtoV2(serverConn, backendWrite) + }() + + header, err := proxyproto.Read(bufio.NewReader(backendRead)) + require.NoError(t, err) + require.NotNil(t, header, "should have received a proxy protocol header") + + writeErr := <-writeDone + require.NoError(t, writeErr) + + assert.Equal(t, byte(2), header.Version, "version should be 2") + assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY") + assert.Equal(t, proxyproto.TCPv6, header.TransportProtocol, "transport should be TCPv6") + + expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr) + actualSrc := header.SourceAddr.(*net.TCPAddr) + assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr") + assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr") + + expectedDst := serverConn.LocalAddr().(*net.TCPAddr) + actualDst := header.DestinationAddr.(*net.TCPAddr) + assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr") + assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr") +} diff --git a/proxy/internal/tcp/relay.go b/proxy/internal/tcp/relay.go new file mode 100644 index 000000000..39949818d --- /dev/null +++ b/proxy/internal/tcp/relay.go @@ -0,0 +1,156 @@ +package tcp + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/netutil" +) + +// errIdleTimeout is returned when a relay connection is closed due to inactivity. +var errIdleTimeout = errors.New("idle timeout") + +// DefaultIdleTimeout is the default idle timeout for TCP relay connections. +// A zero value disables idle timeout checking. +const DefaultIdleTimeout = 5 * time.Minute + +// halfCloser is implemented by connections that support half-close +// (e.g. *net.TCPConn). When one copy direction finishes, we signal +// EOF to the remote by closing the write side while keeping the read +// side open so the other direction can drain. +type halfCloser interface { + CloseWrite() error +} + +// copyBufPool avoids allocating a new 32KB buffer per io.Copy call. +var copyBufPool = sync.Pool{ + New: func() any { + buf := make([]byte, 32*1024) + return &buf + }, +} + +// Relay copies data bidirectionally between src and dst until both +// sides are done or the context is canceled. When idleTimeout is +// non-zero, each direction's read is deadline-guarded; if no data +// flows within the timeout the connection is torn down. When one +// direction finishes, it half-closes the write side of the +// destination (if supported) to signal EOF, allowing the other +// direction to drain gracefully before the full connection teardown. +func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-ctx.Done() + _ = src.Close() + _ = dst.Close() + }() + + var wg sync.WaitGroup + wg.Add(2) + + var errSrcToDst, errDstToSrc error + + go func() { + defer wg.Done() + srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout) + halfClose(dst) + cancel() + }() + + go func() { + defer wg.Done() + dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout) + halfClose(src) + cancel() + }() + + wg.Wait() + + if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) { + logger.Debug("relay closed due to idle timeout") + } + if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) { + logger.Debugf("relay copy error (src→dst): %v", errSrcToDst) + } + if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) { + logger.Debugf("relay copy error (dst→src): %v", errDstToSrc) + } + + return srcToDst, dstToSrc +} + +// copyWithIdleTimeout copies from src to dst using a pooled buffer. +// When idleTimeout > 0 it sets a read deadline on src before each +// read and treats a timeout as an idle-triggered close. +func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) { + bufp := copyBufPool.Get().(*[]byte) + defer copyBufPool.Put(bufp) + + if idleTimeout <= 0 { + return io.CopyBuffer(dst, src, *bufp) + } + + conn, ok := src.(net.Conn) + if !ok { + return io.CopyBuffer(dst, src, *bufp) + } + + buf := *bufp + var total int64 + for { + if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil { + return total, err + } + nr, readErr := src.Read(buf) + if nr > 0 { + n, err := checkedWrite(dst, buf[:nr]) + total += n + if err != nil { + return total, err + } + } + if readErr != nil { + if netutil.IsTimeout(readErr) { + return total, errIdleTimeout + } + return total, readErr + } + } +} + +// checkedWrite writes buf to dst and returns the number of bytes written. +// It guards against short writes and negative counts per io.Copy convention. +func checkedWrite(dst io.Writer, buf []byte) (int64, error) { + nw, err := dst.Write(buf) + if nw < 0 || nw > len(buf) { + nw = 0 + } + if err != nil { + return int64(nw), err + } + if nw != len(buf) { + return int64(nw), io.ErrShortWrite + } + return int64(nw), nil +} + +func isExpectedCopyError(err error) bool { + return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err) +} + +// halfClose attempts to half-close the write side of the connection. +// If the connection does not support half-close, this is a no-op. +func halfClose(conn net.Conn) { + if hc, ok := conn.(halfCloser); ok { + // Best-effort; the full close will follow shortly. + _ = hc.CloseWrite() + } +} diff --git a/proxy/internal/tcp/relay_test.go b/proxy/internal/tcp/relay_test.go new file mode 100644 index 000000000..e42d65b9d --- /dev/null +++ b/proxy/internal/tcp/relay_test.go @@ -0,0 +1,210 @@ +package tcp + +import ( + "context" + "fmt" + "io" + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/netutil" +) + +func TestRelay_BidirectionalCopy(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + srcData := []byte("hello from src") + dstData := []byte("hello from dst") + + // dst side: write response first, then read + close. + go func() { + _, _ = dstClient.Write(dstData) + buf := make([]byte, 256) + _, _ = dstClient.Read(buf) + dstClient.Close() + }() + + // src side: read the response, then send data + close. + go func() { + buf := make([]byte, 256) + _, _ = srcClient.Read(buf) + _, _ = srcClient.Write(srcData) + srcClient.Close() + }() + + s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0) + + assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst") + assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src") +} + +func TestRelay_ContextCancellation(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + defer srcClient.Close() + defer dstClient.Close() + + logger := log.NewEntry(log.StandardLogger()) + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + Relay(ctx, logger, srcServer, dstServer, 0) + close(done) + }() + + // Cancel should cause Relay to return. + cancel() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Relay did not return after context cancellation") + } +} + +func TestRelay_OneSideClosed(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + defer dstClient.Close() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // Close src immediately. Relay should complete without hanging. + srcClient.Close() + + done := make(chan struct{}) + go func() { + Relay(ctx, logger, srcServer, dstServer, 0) + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Relay did not return after one side closed") + } +} + +func TestRelay_LargeTransfer(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // 1MB of data. + data := make([]byte, 1<<20) + for i := range data { + data[i] = byte(i % 256) + } + + go func() { + _, _ = srcClient.Write(data) + srcClient.Close() + }() + + errCh := make(chan error, 1) + go func() { + received, err := io.ReadAll(dstClient) + if err != nil { + errCh <- err + return + } + if len(received) != len(data) { + errCh <- fmt.Errorf("expected %d bytes, got %d", len(data), len(received)) + return + } + errCh <- nil + dstClient.Close() + }() + + s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0) + assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes") + require.NoError(t, <-errCh) +} + +func TestRelay_IdleTimeout(t *testing.T) { + // Use real TCP connections so SetReadDeadline works (net.Pipe + // does not support deadlines). + srcLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer srcLn.Close() + + dstLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer dstLn.Close() + + srcClient, err := net.Dial("tcp", srcLn.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer srcClient.Close() + + srcServer, err := srcLn.Accept() + if err != nil { + t.Fatal(err) + } + + dstClient, err := net.Dial("tcp", dstLn.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer dstClient.Close() + + dstServer, err := dstLn.Accept() + if err != nil { + t.Fatal(err) + } + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // Send initial data to prove the relay works. + go func() { + _, _ = srcClient.Write([]byte("ping")) + }() + + done := make(chan struct{}) + var s2d, d2s int64 + go func() { + s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond) + close(done) + }() + + // Read the forwarded data on the dst side. + buf := make([]byte, 64) + n, err := dstClient.Read(buf) + assert.NoError(t, err) + assert.Equal(t, "ping", string(buf[:n])) + + // Now stop sending. The relay should close after the idle timeout. + select { + case <-done: + assert.Greater(t, s2d, int64(0), "should have transferred initial data") + _ = d2s + case <-time.After(5 * time.Second): + t.Fatal("Relay did not exit after idle timeout") + } +} + +func TestIsExpectedError(t *testing.T) { + assert.True(t, netutil.IsExpectedError(net.ErrClosed)) + assert.True(t, netutil.IsExpectedError(context.Canceled)) + assert.True(t, netutil.IsExpectedError(io.EOF)) + assert.False(t, netutil.IsExpectedError(io.ErrUnexpectedEOF)) +} diff --git a/proxy/internal/tcp/router.go b/proxy/internal/tcp/router.go new file mode 100644 index 000000000..84fde0731 --- /dev/null +++ b/proxy/internal/tcp/router.go @@ -0,0 +1,570 @@ +package tcp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "slices" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/accesslog" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +// defaultDialTimeout is the fallback dial timeout when no per-route +// timeout is configured. +const defaultDialTimeout = 30 * time.Second + +// SNIHost is a typed key for SNI hostname lookups. +type SNIHost string + +// RouteType specifies how a connection should be handled. +type RouteType int + +const ( + // RouteHTTP routes the connection through the HTTP reverse proxy. + RouteHTTP RouteType = iota + // RouteTCP relays the connection directly to the backend (TLS passthrough). + RouteTCP +) + +const ( + // sniPeekTimeout is the deadline for reading the TLS ClientHello. + sniPeekTimeout = 5 * time.Second + // DefaultDrainTimeout is the default grace period for in-flight relay + // connections to finish during shutdown. + DefaultDrainTimeout = 30 * time.Second + // DefaultMaxRelayConns is the default cap on concurrent TCP relay connections per router. + DefaultMaxRelayConns = 4096 + // httpChannelBuffer is the capacity of the channel feeding HTTP connections. + httpChannelBuffer = 4096 +) + +// DialResolver returns a DialContextFunc for the given account. +type DialResolver func(accountID types.AccountID) (types.DialContextFunc, error) + +// Route describes where a connection for a given SNI should be sent. +type Route struct { + Type RouteType + AccountID types.AccountID + ServiceID types.ServiceID + // Domain is the service's configured domain, used for access log entries. + Domain string + // Protocol is the frontend protocol (tcp, tls), used for access log entries. + Protocol accesslog.Protocol + // Target is the backend address for TCP relay (e.g. "10.0.0.5:5432"). + Target string + // ProxyProtocol enables sending a PROXY protocol v2 header to the backend. + ProxyProtocol bool + // DialTimeout overrides the default dial timeout for this route. + // Zero uses defaultDialTimeout. + DialTimeout time.Duration +} + +// l4Logger sends layer-4 access log entries to the management server. +type l4Logger interface { + LogL4(entry accesslog.L4Entry) +} + +// RelayObserver receives callbacks for TCP relay lifecycle events. +// All methods must be safe for concurrent use. +type RelayObserver interface { + TCPRelayStarted(accountID types.AccountID) + TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64) + TCPRelayDialError(accountID types.AccountID) + TCPRelayRejected(accountID types.AccountID) +} + +// Router accepts raw TCP connections on a shared listener, peeks at +// the TLS ClientHello to extract the SNI, and routes the connection +// to either the HTTP reverse proxy or a direct TCP relay. +type Router struct { + logger *log.Logger + // httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter. + httpCh chan net.Conn + httpListener *chanListener + mu sync.RWMutex + routes map[SNIHost][]Route + fallback *Route + draining bool + dialResolve DialResolver + activeConns sync.WaitGroup + activeRelays sync.WaitGroup + relaySem chan struct{} + drainDone chan struct{} + observer RelayObserver + accessLog l4Logger + // svcCtxs tracks a context per service ID. All relay goroutines for a + // service derive from its context; canceling it kills them immediately. + svcCtxs map[types.ServiceID]context.Context + svcCancels map[types.ServiceID]context.CancelFunc +} + +// NewRouter creates a new SNI-based connection router. +func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router { + httpCh := make(chan net.Conn, httpChannelBuffer) + return &Router{ + logger: logger, + httpCh: httpCh, + httpListener: newChanListener(httpCh, addr), + routes: make(map[SNIHost][]Route), + dialResolve: dialResolve, + relaySem: make(chan struct{}, DefaultMaxRelayConns), + svcCtxs: make(map[types.ServiceID]context.Context), + svcCancels: make(map[types.ServiceID]context.CancelFunc), + } +} + +// NewPortRouter creates a Router for a dedicated port without an HTTP +// channel. Connections that don't match any SNI route fall through to +// the fallback relay (if set) or are closed. +func NewPortRouter(logger *log.Logger, dialResolve DialResolver) *Router { + return &Router{ + logger: logger, + routes: make(map[SNIHost][]Route), + dialResolve: dialResolve, + relaySem: make(chan struct{}, DefaultMaxRelayConns), + svcCtxs: make(map[types.ServiceID]context.Context), + svcCancels: make(map[types.ServiceID]context.CancelFunc), + } +} + +// HTTPListener returns a net.Listener that yields connections routed +// to the HTTP handler. Use this with http.Server.ServeTLS. +func (r *Router) HTTPListener() net.Listener { + return r.httpListener +} + +// AddRoute registers an SNI route. Multiple routes for the same host are +// stored and resolved by priority at lookup time (HTTP > TCP). +// Empty host is ignored to prevent conflicts with ECH/ESNI fallback. +func (r *Router) AddRoute(host SNIHost, route Route) { + if host == "" { + return + } + + r.mu.Lock() + defer r.mu.Unlock() + + routes := r.routes[host] + for i, existing := range routes { + if existing.ServiceID == route.ServiceID { + r.cancelServiceLocked(route.ServiceID) + routes[i] = route + return + } + } + r.routes[host] = append(routes, route) +} + +// RemoveRoute removes the route for the given host and service ID. +// Active relay connections for the service are closed immediately. +// If other routes remain for the host, they are preserved. +func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) { + r.mu.Lock() + defer r.mu.Unlock() + + r.routes[host] = slices.DeleteFunc(r.routes[host], func(route Route) bool { + return route.ServiceID == svcID + }) + if len(r.routes[host]) == 0 { + delete(r.routes, host) + } + r.cancelServiceLocked(svcID) +} + +// SetFallback registers a catch-all route for connections that don't +// match any SNI route. On a port router this handles plain TCP relay; +// on the main router it takes priority over the HTTP channel. +func (r *Router) SetFallback(route Route) { + r.mu.Lock() + defer r.mu.Unlock() + r.fallback = &route +} + +// RemoveFallback clears the catch-all fallback route and closes any +// active relay connections for the given service. +func (r *Router) RemoveFallback(svcID types.ServiceID) { + r.mu.Lock() + defer r.mu.Unlock() + r.fallback = nil + r.cancelServiceLocked(svcID) +} + +// SetObserver sets the relay lifecycle observer. Must be called before Serve. +func (r *Router) SetObserver(obs RelayObserver) { + r.mu.Lock() + defer r.mu.Unlock() + r.observer = obs +} + +// SetAccessLogger sets the L4 access logger. Must be called before Serve. +func (r *Router) SetAccessLogger(l l4Logger) { + r.mu.Lock() + defer r.mu.Unlock() + r.accessLog = l +} + +// getObserver returns the current relay observer under the read lock. +func (r *Router) getObserver() RelayObserver { + r.mu.RLock() + defer r.mu.RUnlock() + return r.observer +} + +// IsEmpty returns true when the router has no SNI routes and no fallback. +func (r *Router) IsEmpty() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.routes) == 0 && r.fallback == nil +} + +// Serve accepts connections from ln and routes them based on SNI. +// It blocks until ctx is canceled or ln is closed, then drains +// active relay connections up to DefaultDrainTimeout. +func (r *Router) Serve(ctx context.Context, ln net.Listener) error { + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + _ = ln.Close() + if r.httpListener != nil { + r.httpListener.Close() + } + case <-done: + } + }() + + for { + conn, err := ln.Accept() + if err != nil { + if ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + if ok := r.Drain(DefaultDrainTimeout); !ok { + r.logger.Warn("timed out waiting for connections to drain") + } + return nil + } + r.logger.Debugf("SNI router accept: %v", err) + continue + } + r.activeConns.Add(1) + go func() { + defer r.activeConns.Done() + r.handleConn(ctx, conn) + }() + } +} + +// handleConn peeks at the TLS ClientHello and routes the connection. +func (r *Router) handleConn(ctx context.Context, conn net.Conn) { + // Fast path: when no SNI routes and no HTTP channel exist (pure TCP + // fallback port), skip the TLS peek entirely to avoid read errors on + // non-TLS connections and reduce latency. + if r.isFallbackOnly() { + r.handleUnmatched(ctx, conn) + return + } + + if err := conn.SetReadDeadline(time.Now().Add(sniPeekTimeout)); err != nil { + r.logger.Debugf("set SNI peek deadline: %v", err) + _ = conn.Close() + return + } + + sni, wrapped, err := PeekClientHello(conn) + if err != nil { + r.logger.Debugf("SNI peek: %v", err) + if wrapped != nil { + r.handleUnmatched(ctx, wrapped) + } else { + _ = conn.Close() + } + return + } + + if err := wrapped.SetReadDeadline(time.Time{}); err != nil { + r.logger.Debugf("clear SNI peek deadline: %v", err) + _ = wrapped.Close() + return + } + + host := SNIHost(sni) + route, ok := r.lookupRoute(host) + if !ok { + r.handleUnmatched(ctx, wrapped) + return + } + + if route.Type == RouteHTTP { + r.sendToHTTP(wrapped) + return + } + + if err := r.relayTCP(ctx, wrapped, host, route); err != nil { + r.logger.WithFields(log.Fields{ + "sni": host, + "service_id": route.ServiceID, + "target": route.Target, + }).Warnf("TCP relay: %v", err) + _ = wrapped.Close() + } +} + +// isFallbackOnly returns true when the router has no SNI routes and no HTTP +// channel, meaning all connections should go directly to the fallback relay. +func (r *Router) isFallbackOnly() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.routes) == 0 && r.httpCh == nil +} + +// handleUnmatched routes a connection that didn't match any SNI route. +// This includes ECH/ESNI connections where the cleartext SNI is empty. +// It tries the fallback relay first, then the HTTP channel, and closes +// the connection if neither is available. +func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) { + r.mu.RLock() + fb := r.fallback + r.mu.RUnlock() + + if fb != nil { + if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil { + r.logger.WithFields(log.Fields{ + "service_id": fb.ServiceID, + "target": fb.Target, + }).Warnf("TCP relay (fallback): %v", err) + _ = conn.Close() + } + return + } + r.sendToHTTP(conn) +} + +// lookupRoute returns the highest-priority route for the given SNI host. +// HTTP routes take precedence over TCP routes. +func (r *Router) lookupRoute(host SNIHost) (Route, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + routes, ok := r.routes[host] + if !ok || len(routes) == 0 { + return Route{}, false + } + best := routes[0] + for _, route := range routes[1:] { + if route.Type < best.Type { + best = route + } + } + return best, true +} + +// sendToHTTP feeds the connection to the HTTP handler via the channel. +// If no HTTP channel is configured (port router), the router is +// draining, or the channel is full, the connection is closed. +func (r *Router) sendToHTTP(conn net.Conn) { + if r.httpCh == nil { + _ = conn.Close() + return + } + + r.mu.RLock() + draining := r.draining + r.mu.RUnlock() + + if draining { + _ = conn.Close() + return + } + + select { + case r.httpCh <- conn: + default: + r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr()) + _ = conn.Close() + } +} + +// Drain prevents new relay connections from starting and waits for all +// in-flight connection handlers and active relays to finish, up to the +// given timeout. Returns true if all completed, false on timeout. +func (r *Router) Drain(timeout time.Duration) bool { + r.mu.Lock() + r.draining = true + if r.drainDone == nil { + done := make(chan struct{}) + go func() { + r.activeConns.Wait() + r.activeRelays.Wait() + close(done) + }() + r.drainDone = done + } + done := r.drainDone + r.mu.Unlock() + + select { + case <-done: + return true + case <-time.After(timeout): + return false + } +} + +// cancelServiceLocked cancels and removes the context for the given service, +// closing all its active relay connections. Must be called with mu held. +func (r *Router) cancelServiceLocked(svcID types.ServiceID) { + if cancel, ok := r.svcCancels[svcID]; ok { + cancel() + delete(r.svcCtxs, svcID) + delete(r.svcCancels, svcID) + } +} + +// relayTCP sets up and runs a bidirectional TCP relay. +// The caller owns conn and must close it if this method returns an error. +// On success (nil error), both conn and backend are closed by the relay. +func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error { + svcCtx, err := r.acquireRelay(ctx, route) + if err != nil { + return err + } + defer func() { + <-r.relaySem + r.activeRelays.Done() + }() + + backend, err := r.dialBackend(svcCtx, route) + if err != nil { + obs := r.getObserver() + if obs != nil { + obs.TCPRelayDialError(route.AccountID) + } + return err + } + + if route.ProxyProtocol { + if err := writeProxyProtoV2(conn, backend); err != nil { + _ = backend.Close() + return fmt.Errorf("write PROXY protocol header: %w", err) + } + } + + obs := r.getObserver() + if obs != nil { + obs.TCPRelayStarted(route.AccountID) + } + + entry := r.logger.WithFields(log.Fields{ + "sni": sni, + "service_id": route.ServiceID, + "target": route.Target, + }) + entry.Debug("TCP relay started") + + start := time.Now() + s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout) + elapsed := time.Since(start) + + if obs != nil { + obs.TCPRelayEnded(route.AccountID, elapsed, s2d, d2s) + } + entry.Debugf("TCP relay ended (client→backend: %d bytes, backend→client: %d bytes)", s2d, d2s) + + r.logL4Entry(route, conn, elapsed, s2d, d2s) + return nil +} + +// acquireRelay checks draining state, increments activeRelays, and acquires +// a semaphore slot. Returns the per-service context on success. +// The caller must release the semaphore and call activeRelays.Done() when done. +func (r *Router) acquireRelay(ctx context.Context, route Route) (context.Context, error) { + r.mu.Lock() + if r.draining { + r.mu.Unlock() + return nil, errors.New("router is draining") + } + r.activeRelays.Add(1) + svcCtx := r.getOrCreateServiceCtxLocked(ctx, route.ServiceID) + r.mu.Unlock() + + select { + case r.relaySem <- struct{}{}: + return svcCtx, nil + default: + r.activeRelays.Done() + obs := r.getObserver() + if obs != nil { + obs.TCPRelayRejected(route.AccountID) + } + return nil, errors.New("TCP relay connection limit reached") + } +} + +// dialBackend resolves the dialer for the route's account and dials the backend. +func (r *Router) dialBackend(svcCtx context.Context, route Route) (net.Conn, error) { + dialFn, err := r.dialResolve(route.AccountID) + if err != nil { + return nil, fmt.Errorf("resolve dialer: %w", err) + } + + dialTimeout := route.DialTimeout + if dialTimeout <= 0 { + dialTimeout = defaultDialTimeout + } + dialCtx, dialCancel := context.WithTimeout(svcCtx, dialTimeout) + backend, err := dialFn(dialCtx, "tcp", route.Target) + dialCancel() + if err != nil { + return nil, fmt.Errorf("dial backend %s: %w", route.Target, err) + } + return backend, nil +} + +// logL4Entry sends a TCP relay access log entry if an access logger is configured. +func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, bytesUp, bytesDown int64) { + r.mu.RLock() + al := r.accessLog + r.mu.RUnlock() + + if al == nil { + return + } + + var sourceIP netip.Addr + if remote := conn.RemoteAddr(); remote != nil { + if ap, err := netip.ParseAddrPort(remote.String()); err == nil { + sourceIP = ap.Addr().Unmap() + } + } + + al.LogL4(accesslog.L4Entry{ + AccountID: route.AccountID, + ServiceID: route.ServiceID, + Protocol: route.Protocol, + Host: route.Domain, + SourceIP: sourceIP, + DurationMs: duration.Milliseconds(), + BytesUpload: bytesUp, + BytesDownload: bytesDown, + }) +} + +// getOrCreateServiceCtxLocked returns the context for a service, creating one +// if it doesn't exist yet. The context is a child of the server context. +// Must be called with mu held. +func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types.ServiceID) context.Context { + if ctx, ok := r.svcCtxs[svcID]; ok { + return ctx + } + ctx, cancel := context.WithCancel(parent) + r.svcCtxs[svcID] = ctx + r.svcCancels[svcID] = cancel + return ctx +} diff --git a/proxy/internal/tcp/router_test.go b/proxy/internal/tcp/router_test.go new file mode 100644 index 000000000..0e2cfe3e1 --- /dev/null +++ b/proxy/internal/tcp/router_test.go @@ -0,0 +1,1670 @@ +package tcp + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "math/big" + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestRouter_HTTPRouting(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + router := NewRouter(logger, nil, addr) + router.AddRoute("example.com", Route{Type: RouteHTTP}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Dial in a goroutine. The TLS handshake will block since nothing + // completes it on the HTTP side, but we only care about routing. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + // Send a TLS ClientHello manually. + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: "example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + tlsConn.Close() + }() + + // Verify the connection was routed to the HTTP channel. + select { + case conn := <-router.httpCh: + assert.NotNil(t, conn) + conn.Close() + case <-time.After(5 * time.Second): + t.Fatal("no connection received on HTTP channel") + } +} + +func TestRouter_TCPRouting(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + // Set up a TLS backend that the relay will connect to. + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + backendAddr := backendLn.Addr().String() + + // Accept one connection on the backend, echo data back. + backendReady := make(chan struct{}) + go func() { + close(backendReady) + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + <-backendReady + + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendAddr, + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Connect as a TLS client; the proxy should passthrough to the backend. + clientConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer clientConn.Close() + + testData := []byte("hello through TCP passthrough") + _, err = clientConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := clientConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed data through TCP passthrough") +} + +func TestRouter_UnknownSNIGoesToHTTP(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + router := NewRouter(logger, nil, addr) + // No routes registered. + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: "unknown.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + tlsConn.Close() + }() + + select { + case conn := <-router.httpCh: + assert.NotNil(t, conn) + conn.Close() + case <-time.After(5 * time.Second): + t.Fatal("unknown SNI should be routed to HTTP") + } +} + +// TestRouter_NonTLSConnectionDropped verifies that a non-TLS connection +// on the shared port is closed by the router (SNI peek fails to find a +// valid ClientHello, so there is no route match). +func TestRouter_NonTLSConnectionDropped(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + // Register a TLS passthrough route. Non-TLS should NOT match. + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: "127.0.0.1:9999", + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Send plain HTTP (non-TLS) data. + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: tcp.example.com\r\n\r\n")) + + // Non-TLS traffic on a port with RouteTCP goes to the HTTP channel + // because there's no valid SNI to match. Verify it reaches HTTP. + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "non-TLS connection should fall through to HTTP") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("non-TLS connection was not routed to HTTP") + } +} + +// TestRouter_TLSAndHTTPCoexist verifies that a shared port with both HTTP +// and TLS passthrough routes correctly demuxes based on the SNI hostname. +func TestRouter_TLSAndHTTPCoexist(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + // Backend echoes data. + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + // HTTP route. + router.AddRoute("app.example.com", Route{Type: RouteHTTP}) + // TLS passthrough route. + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // 1. TLS connection with SNI "tcp.example.com" → TLS passthrough. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + testData := []byte("passthrough data") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "TLS passthrough should relay data") + tlsConn.Close() + + // 2. TLS connection with SNI "app.example.com" → HTTP handler. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + c := tls.Client(conn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = c.Handshake() + c.Close() + }() + + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "HTTP SNI should go to HTTP handler") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("HTTP-route connection was not delivered to HTTP handler") + } +} + +func TestRouter_AddRemoveRoute(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, nil, addr) + + router.AddRoute("a.example.com", Route{Type: RouteHTTP, ServiceID: "svc-a"}) + router.AddRoute("b.example.com", Route{Type: RouteTCP, ServiceID: "svc-b", Target: "10.0.0.1:5432"}) + + route, ok := router.lookupRoute("a.example.com") + assert.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type) + + route, ok = router.lookupRoute("b.example.com") + assert.True(t, ok) + assert.Equal(t, RouteTCP, route.Type) + + router.RemoveRoute("a.example.com", "svc-a") + _, ok = router.lookupRoute("a.example.com") + assert.False(t, ok) +} + +func TestChanListener_AcceptAndClose(t *testing.T) { + ch := make(chan net.Conn, 1) + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + ln := newChanListener(ch, addr) + + assert.Equal(t, addr, ln.Addr()) + + // Send a connection. + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + ch <- serverConn + + conn, err := ln.Accept() + require.NoError(t, err) + assert.Equal(t, serverConn, conn) + + // Close should cause Accept to return error. + require.NoError(t, ln.Close()) + // Double close should be safe. + require.NoError(t, ln.Close()) + + _, err = ln.Accept() + assert.ErrorIs(t, err, net.ErrClosed) +} + +func TestRouter_HTTPPrecedenceGuard(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, nil, addr) + + host := SNIHost("app.example.com") + + t.Run("http takes precedence over tcp at lookup", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type, "HTTP route must take precedence over TCP") + assert.Equal(t, types.ServiceID("svc-http"), route.ServiceID) + + router.RemoveRoute(host, "svc-http") + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("tcp becomes active when http is removed", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + + router.RemoveRoute(host, "svc-http") + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteTCP, route.Type, "TCP should take over after HTTP removal") + assert.Equal(t, types.ServiceID("svc-tcp"), route.ServiceID) + + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("order of add does not matter", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type, "HTTP takes precedence regardless of add order") + + router.RemoveRoute(host, "svc-http") + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("same service id updates in place", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-1", Target: "10.0.0.1:443"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-1", Target: "10.0.0.2:443"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, "10.0.0.2:443", route.Target, "route should be updated in place") + + router.RemoveRoute(host, "svc-1") + _, ok = router.lookupRoute(host) + assert.False(t, ok) + }) + + t.Run("double remove is safe", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-1"}) + router.RemoveRoute(host, "svc-1") + router.RemoveRoute(host, "svc-1") + + _, ok := router.lookupRoute(host) + assert.False(t, ok, "route should be gone after removal") + }) + + t.Run("remove does not affect other hosts", func(t *testing.T) { + router.AddRoute("a.example.com", Route{Type: RouteHTTP, ServiceID: "svc-a"}) + router.AddRoute("b.example.com", Route{Type: RouteTCP, ServiceID: "svc-b", Target: "10.0.0.2:22"}) + + router.RemoveRoute("a.example.com", "svc-a") + + _, ok := router.lookupRoute(SNIHost("a.example.com")) + assert.False(t, ok) + + route, ok := router.lookupRoute(SNIHost("b.example.com")) + require.True(t, ok) + assert.Equal(t, RouteTCP, route.Type, "removing one host must not affect another") + + router.RemoveRoute("b.example.com", "svc-b") + }) +} + +func TestRouter_SetRemoveFallback(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + assert.True(t, router.IsEmpty(), "new port router should be empty") + + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb", Target: "10.0.0.1:5432"}) + assert.False(t, router.IsEmpty(), "router with fallback should not be empty") + + router.AddRoute("a.example.com", Route{Type: RouteTCP, ServiceID: "svc-a", Target: "10.0.0.2:443"}) + assert.False(t, router.IsEmpty()) + + router.RemoveFallback("svc-fb") + assert.False(t, router.IsEmpty(), "router with SNI route should not be empty") + + router.RemoveRoute("a.example.com", "svc-a") + assert.True(t, router.IsEmpty(), "router with no routes and no fallback should be empty") +} + +func TestPortRouter_FallbackRelaysData(t *testing.T) { + // Backend echo server. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Plain TCP (non-TLS) connection should be relayed via fallback. + // Use exactly 5 bytes. PeekClientHello reads 5 bytes as the TLS + // header, so a single 5-byte write lands as one chunk at the backend. + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + testData := []byte("hello") + _, err = conn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed data through fallback relay") +} + +func TestPortRouter_FallbackOnUnknownSNI(t *testing.T) { + // Backend TLS echo server. + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + // Only a fallback, no SNI route for "unknown.example.com". + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // TLS with unknown SNI → fallback relay to TLS backend. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer tlsConn.Close() + + testData := []byte("hello through fallback TLS") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "unknown SNI should relay through fallback") +} + +func TestPortRouter_SNIWinsOverFallback(t *testing.T) { + // Two backend echo servers: one for SNI match, one for fallback. + sniBacked := startEchoTLS(t) + fbBacked := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "sni-service", + Target: sniBacked.Addr().String(), + }) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "fb-service", + Target: fbBacked.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // TLS with matching SNI should go to SNI backend, not fallback. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer tlsConn.Close() + + testData := []byte("SNI route data") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "SNI match should use SNI route, not fallback") +} + +func TestPortRouter_NoFallbackNoHTTP_Closes(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, _ = conn.Write([]byte("hello")) + + // Connection should be closed by the router (no fallback, no HTTP). + buf := make([]byte, 1) + _ = conn.SetReadDeadline(time.Now().Add(3 * time.Second)) + _, err = conn.Read(buf) + assert.Error(t, err, "connection should be closed when no fallback and no HTTP channel") +} + +func TestRouter_FallbackAndHTTPCoexist(t *testing.T) { + // Fallback backend echo server (plain TCP). + fbBackend, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer fbBackend.Close() + + go func() { + conn, err := fbBackend.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, dialResolve, addr) + + // HTTP route for known SNI. + router.AddRoute("app.example.com", Route{Type: RouteHTTP}) + // Fallback for non-TLS / unknown SNI. + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "fb-service", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // 1. TLS with known HTTP SNI → should go to HTTP channel. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + c := tls.Client(conn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = c.Handshake() + c.Close() + }() + + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "known HTTP SNI should go to HTTP channel") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("HTTP-route connection was not delivered to HTTP handler") + } + + // 2. Plain TCP (non-TLS) → should go to fallback, not HTTP. + // Use exactly 5 bytes to match PeekClientHello header size. + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + testData := []byte("plain") + _, err = conn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "non-TLS should be relayed via fallback, not HTTP") +} + +// startEchoTLS starts a TLS echo server and returns the listener. +func startEchoTLS(t *testing.T) net.Listener { + t.Helper() + + cert := generateSelfSignedCert(t) + ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + if _, err := conn.Write(buf[:n]); err != nil { + return + } + } + }() + + return ln +} + +func generateSelfSignedCert(t *testing.T) tls.Certificate { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + DNSNames: []string{"tcp.example.com"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + return tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: key, + } +} + +func TestRouter_DrainWaitsForRelays(t *testing.T) { + logger := log.StandardLogger() + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + // Accept connections: echo first message, then hold open until told to close. + closeBackend := make(chan struct{}) + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + <-closeBackend + }(conn) + } + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + serveDone := make(chan struct{}) + go func() { + _ = router.Serve(ctx, ln) + close(serveDone) + }() + + // Open a relay connection (non-TLS, hits fallback). + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + _, _ = conn.Write([]byte("hello")) + + // Wait for the echo to confirm the relay is fully established. + buf := make([]byte, 16) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + _ = conn.SetReadDeadline(time.Time{}) + + // Drain with a short timeout should fail because the relay is still active. + assert.False(t, router.Drain(50*time.Millisecond), "drain should timeout with active relay") + + // Close backend connections so relays finish. + close(closeBackend) + _ = conn.Close() + + // Drain should now complete quickly. + assert.True(t, router.Drain(2*time.Second), "drain should succeed after relays end") + + cancel() + <-serveDone +} + +func TestRouter_DrainEmptyReturnsImmediately(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + start := time.Now() + ok := router.Drain(5 * time.Second) + elapsed := time.Since(start) + + assert.True(t, ok) + assert.Less(t, elapsed, 100*time.Millisecond, "drain with no relays should return immediately") +} + +// TestRemoveRoute_KillsActiveRelays verifies that removing a route +// immediately kills active relay connections for that service. +func TestRemoveRoute_KillsActiveRelays(t *testing.T) { + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + // Backend echoes first message, then holds connection open. + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + // Hold the connection open. + for { + if _, err := c.Read(buf); err != nil { + return + } + } + }(conn) + } + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + ServiceID: "svc-1", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Establish a relay connection. + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer conn.Close() + _, err = conn.Write([]byte("hello")) + require.NoError(t, err) + + // Wait for echo to confirm relay is established. + buf := make([]byte, 16) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + _ = conn.SetReadDeadline(time.Time{}) + + // Remove the fallback: should kill the active relay. + router.RemoveFallback("svc-1") + + // The client connection should see an error (server closed). + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Read(buf) + assert.Error(t, err, "connection should be killed after service removal") +} + +// TestRemoveRoute_KillsSNIRelays verifies that removing an SNI route +// kills its active relays without affecting other services. +func TestRemoveRoute_KillsSNIRelays(t *testing.T) { + backend := startEchoTLS(t) + defer backend.Close() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tls.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-tls", + Target: backend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Establish a TLS relay. + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "tls.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + defer tlsConn.Close() + + _, err = tlsConn.Write([]byte("ping")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "ping", string(buf[:n])) + + // Remove the route: active relay should die. + router.RemoveRoute("tls.example.com", "svc-tls") + + _ = tlsConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = tlsConn.Read(buf) + assert.Error(t, err, "TLS relay should be killed after route removal") +} + +// TestPortRouter_SNIAndTCPFallbackCoexist verifies that a single port can +// serve both SNI-routed TLS passthrough and plain TCP fallback simultaneously. +func TestPortRouter_SNIAndTCPFallbackCoexist(t *testing.T) { + sniBackend := startEchoTLS(t) + fbBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + + // SNI route for a specific domain. + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-sni", + Target: sniBackend.Addr().String(), + }) + // TCP fallback for everything else. + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "acct-2", + ServiceID: "svc-fb", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // 1. TLS with matching SNI → goes to SNI backend. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + _, err = tlsConn.Write([]byte("sni-data")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "sni-data", string(buf[:n]), "SNI match → SNI backend") + tlsConn.Close() + + // 2. Plain TCP (no TLS) → goes to fallback. + tcpConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + + _, err = tcpConn.Write([]byte("plain")) + require.NoError(t, err) + n, err = tcpConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "plain", string(buf[:n]), "plain TCP → fallback backend") + tcpConn.Close() + + // 3. TLS with unknown SNI → also goes to fallback. + unknownBackend := startEchoTLS(t) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "acct-2", + ServiceID: "svc-fb", + Target: unknownBackend.Addr().String(), + }) + + unknownTLS, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "unknown.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + _, err = unknownTLS.Write([]byte("unknown-sni")) + require.NoError(t, err) + n, err = unknownTLS.Read(buf) + require.NoError(t, err) + assert.Equal(t, "unknown-sni", string(buf[:n]), "unknown SNI → fallback backend") + unknownTLS.Close() +} + +// TestPortRouter_UpdateRouteSwapsSNI verifies that updating a route +// (remove + add with different target) correctly routes to the new backend. +func TestPortRouter_UpdateRouteSwapsSNI(t *testing.T) { + backend1 := startEchoTLS(t) + backend2 := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Initial route → backend1. + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: backend1.Addr().String(), + }) + + conn1, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn1.Write([]byte("v1")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "v1", string(buf[:n])) + conn1.Close() + + // Update: remove old route, add new → backend2. + router.RemoveRoute("db.example.com", "svc-db") + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: backend2.Addr().String(), + }) + + conn2, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn2.Write([]byte("v2")) + require.NoError(t, err) + n, err = conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "v2", string(buf[:n])) + conn2.Close() +} + +// TestPortRouter_RemoveSNIFallsThrough verifies that after removing an +// SNI route, connections for that domain fall through to the fallback. +func TestPortRouter_RemoveSNIFallsThrough(t *testing.T) { + sniBackend := startEchoTLS(t) + fbBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: sniBackend.Addr().String(), + }) + router.SetFallback(Route{ + Type: RouteTCP, + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Before removal: SNI matches → sniBackend. + conn1, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn1.Write([]byte("before")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "before", string(buf[:n])) + conn1.Close() + + // Remove SNI route. Should fall through to fallback. + router.RemoveRoute("db.example.com", "svc-db") + + conn2, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn2.Write([]byte("after")) + require.NoError(t, err) + n, err = conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "after", string(buf[:n]), "after removal, should reach fallback") + conn2.Close() +} + +// TestPortRouter_RemoveFallbackCloses verifies that after removing the +// fallback, non-matching connections are closed. +func TestPortRouter_RemoveFallbackCloses(t *testing.T) { + fbBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + ServiceID: "svc-fb", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // With fallback: plain TCP works. + conn1, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + _, err = conn1.Write([]byte("hello")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + conn1.Close() + + // Remove fallback. + router.RemoveFallback("svc-fb") + + // Without fallback on a port router (no HTTP channel): connection should be closed. + conn2, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn2.Close() + _, _ = conn2.Write([]byte("bye")) + _ = conn2.SetReadDeadline(time.Now().Add(3 * time.Second)) + _, err = conn2.Read(buf) + assert.Error(t, err, "without fallback, connection should be closed") +} + +// TestPortRouter_HTTPToTLSTransition verifies that switching a service from +// HTTP-only to TLS-only via remove+add doesn't orphan the old HTTP route. +func TestPortRouter_HTTPToTLSTransition(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + tlsBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewRouter(logger, dialResolve, addr) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Phase 1: HTTP-only. SNI connections go to HTTP channel. + router.AddRoute("app.example.com", Route{Type: RouteHTTP, AccountID: "acct-1", ServiceID: "svc-1"}) + + httpConn := router.HTTPListener() + connDone := make(chan struct{}) + go func() { + defer close(connDone) + c, err := httpConn.Accept() + if err == nil { + c.Close() + } + }() + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + if err == nil { + tlsConn.Close() + } + select { + case <-connDone: + case <-time.After(2 * time.Second): + t.Fatal("HTTP listener did not receive connection for HTTP-only route") + } + + // Phase 2: Simulate update to TLS-only (removeMapping + addMapping). + router.RemoveRoute("app.example.com", "svc-1") + router.AddRoute("app.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-1", + Target: tlsBackend.Addr().String(), + }) + + tlsConn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err, "TLS connection should succeed after HTTP→TLS transition") + defer tlsConn2.Close() + + _, err = tlsConn2.Write([]byte("hello-tls")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello-tls", string(buf[:n]), "data should relay to TLS backend") +} + +// TestPortRouter_TLSToHTTPTransition verifies that switching a service from +// TLS-only to HTTP-only via remove+add doesn't orphan the old TLS route. +func TestPortRouter_TLSToHTTPTransition(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + tlsBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewRouter(logger, dialResolve, addr) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Phase 1: TLS-only. Route relays to backend. + router.AddRoute("app.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-1", + Target: tlsBackend.Addr().String(), + }) + + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err, "TLS relay should work before transition") + _, err = tlsConn.Write([]byte("tls-data")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "tls-data", string(buf[:n])) + tlsConn.Close() + + // Phase 2: Simulate update to HTTP-only (removeMapping + addMapping). + router.RemoveRoute("app.example.com", "svc-1") + router.AddRoute("app.example.com", Route{Type: RouteHTTP, AccountID: "acct-1", ServiceID: "svc-1"}) + + // TLS connection should now go to the HTTP listener, NOT to the old TLS backend. + httpConn := router.HTTPListener() + connDone := make(chan struct{}) + go func() { + defer close(connDone) + c, err := httpConn.Accept() + if err == nil { + c.Close() + } + }() + tlsConn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + if err == nil { + tlsConn2.Close() + } + select { + case <-connDone: + case <-time.After(2 * time.Second): + t.Fatal("HTTP listener should receive connection after TLS→HTTP transition") + } +} + +// TestPortRouter_MultiDomainSamePort verifies that two TLS services sharing +// the same port router are independently routable and removable. +func TestPortRouter_MultiDomainSamePort(t *testing.T) { + logger := log.StandardLogger() + backend1 := startEchoTLSMulti(t) + backend2 := startEchoTLSMulti(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + router.AddRoute("svc1.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-1", Target: backend1.Addr().String()}) + router.AddRoute("svc2.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-2", Target: backend2.Addr().String()}) + assert.False(t, router.IsEmpty()) + + // Both domains route independently. + for _, tc := range []struct { + sni string + data string + }{ + {"svc1.example.com", "hello-svc1"}, + {"svc2.example.com", "hello-svc2"}, + } { + conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: tc.sni, InsecureSkipVerify: true}, + ) + require.NoError(t, err, "dial %s", tc.sni) + _, err = conn.Write([]byte(tc.data)) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, tc.data, string(buf[:n])) + conn.Close() + } + + // Remove svc1. Router should NOT be empty (svc2 still present). + router.RemoveRoute("svc1.example.com", "svc-1") + assert.False(t, router.IsEmpty(), "router should not be empty with one route remaining") + + // svc2 still works. + conn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "svc2.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + _, err = conn2.Write([]byte("still-alive")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "still-alive", string(buf[:n])) + conn2.Close() + + // Remove svc2. Router is now empty. + router.RemoveRoute("svc2.example.com", "svc-2") + assert.True(t, router.IsEmpty(), "router should be empty after removing all routes") +} + +// TestPortRouter_SNIAndFallbackLifecycle verifies the full lifecycle of SNI +// routes and TCP fallback coexisting on the same port router, including the +// ordering of add/remove operations. +func TestPortRouter_SNIAndFallbackLifecycle(t *testing.T) { + logger := log.StandardLogger() + sniBackend := startEchoTLS(t) + fallbackBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Step 1: Add fallback first (port mapping), then SNI route (TLS service). + router.SetFallback(Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "pm-1", Target: fallbackBackend.Addr().String()}) + router.AddRoute("tls.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-1", Target: sniBackend.Addr().String()}) + assert.False(t, router.IsEmpty()) + + // SNI traffic goes to TLS backend. + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "tls.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + _, err = tlsConn.Write([]byte("sni-traffic")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "sni-traffic", string(buf[:n])) + tlsConn.Close() + + // Plain TCP goes to fallback. + plainConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + _, err = plainConn.Write([]byte("plain")) + require.NoError(t, err) + n, err = plainConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "plain", string(buf[:n])) + plainConn.Close() + + // Step 2: Remove SNI route. Fallback still works, router not empty. + router.RemoveRoute("tls.example.com", "svc-1") + assert.False(t, router.IsEmpty(), "fallback still present") + + plainConn2, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + // Must send >= 5 bytes so the SNI peek completes immediately + // without waiting for the 5-second peek timeout. + _, err = plainConn2.Write([]byte("after")) + require.NoError(t, err) + n, err = plainConn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "after", string(buf[:n])) + plainConn2.Close() + + // Step 3: Remove fallback. Router is now empty. + router.RemoveFallback("pm-1") + assert.True(t, router.IsEmpty()) +} + +// TestPortRouter_IsEmptyTransitions verifies IsEmpty reflects correct state +// through all add/remove operations. +func TestPortRouter_IsEmptyTransitions(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + assert.True(t, router.IsEmpty(), "new router") + + router.AddRoute("a.com", Route{Type: RouteTCP, ServiceID: "svc-a"}) + assert.False(t, router.IsEmpty(), "after adding route") + + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb1"}) + assert.False(t, router.IsEmpty(), "route + fallback") + + router.RemoveRoute("a.com", "svc-a") + assert.False(t, router.IsEmpty(), "fallback only") + + router.RemoveFallback("svc-fb1") + assert.True(t, router.IsEmpty(), "all removed") + + // Reverse order: fallback first, then route. + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb2"}) + assert.False(t, router.IsEmpty()) + + router.AddRoute("b.com", Route{Type: RouteTCP, ServiceID: "svc-b"}) + assert.False(t, router.IsEmpty()) + + router.RemoveFallback("svc-fb2") + assert.False(t, router.IsEmpty(), "route still present") + + router.RemoveRoute("b.com", "svc-b") + assert.True(t, router.IsEmpty(), "fully empty again") +} + +// startEchoTLSMulti starts a TLS echo server that accepts multiple connections. +func startEchoTLSMulti(t *testing.T) net.Listener { + t.Helper() + + cert := generateSelfSignedCert(t) + ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + }(conn) + } + }() + + return ln +} + +// startEchoPlain starts a plain TCP echo server that reads until newline +// or connection close, then echoes the received data. +func startEchoPlain(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + // Set a read deadline so we don't block forever waiting for more data. + _ = c.SetReadDeadline(time.Now().Add(2 * time.Second)) + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + }(conn) + } + }() + + return ln +} diff --git a/proxy/internal/tcp/snipeek.go b/proxy/internal/tcp/snipeek.go new file mode 100644 index 000000000..25ab8e5ef --- /dev/null +++ b/proxy/internal/tcp/snipeek.go @@ -0,0 +1,191 @@ +package tcp + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" +) + +const ( + // TLS record header is 5 bytes: ContentType(1) + Version(2) + Length(2). + tlsRecordHeaderLen = 5 + // TLS handshake type for ClientHello. + handshakeTypeClientHello = 1 + // TLS ContentType for handshake messages. + contentTypeHandshake = 22 + // SNI extension type (RFC 6066). + extensionServerName = 0 + // SNI host name type. + sniHostNameType = 0 + // maxClientHelloLen caps the ClientHello size we're willing to buffer. + maxClientHelloLen = 16384 + // maxSNILen is the maximum valid DNS hostname length per RFC 1035. + maxSNILen = 253 +) + +// PeekClientHello reads the TLS ClientHello from conn, extracts the SNI +// server name, and returns a wrapped connection that replays the peeked +// bytes transparently. If the data is not a valid TLS ClientHello or +// contains no SNI extension, sni is empty and err is nil. +// +// ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the +// real server name is encrypted inside the encrypted_client_hello +// extension. This parser only reads the cleartext server_name extension +// (type 0x0000), so ECH connections return sni="" and are routed through +// the fallback path (or HTTP channel), which is the correct behavior +// for a transparent proxy that does not terminate TLS. +func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) { + // Read the 5-byte TLS record header into a small stack-friendly buffer. + var header [tlsRecordHeaderLen]byte + if _, err := io.ReadFull(conn, header[:]); err != nil { + return "", nil, fmt.Errorf("read TLS record header: %w", err) + } + + if header[0] != contentTypeHandshake { + return "", newPeekedConn(conn, header[:]), nil + } + + recordLen := int(binary.BigEndian.Uint16(header[3:5])) + if recordLen == 0 || recordLen > maxClientHelloLen { + return "", newPeekedConn(conn, header[:]), nil + } + + // Single allocation for header + payload. The peekedConn takes + // ownership of this buffer, so no further copies are needed. + buf := make([]byte, tlsRecordHeaderLen+recordLen) + copy(buf, header[:]) + + n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:]) + if err != nil { + return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err) + } + + sni = extractSNI(buf[tlsRecordHeaderLen:]) + return sni, newPeekedConn(conn, buf), nil +} + +// extractSNI parses a TLS handshake payload to find the SNI extension. +// Returns empty string if the payload is not a ClientHello or has no SNI. +func extractSNI(payload []byte) string { + if len(payload) < 4 { + return "" + } + + if payload[0] != handshakeTypeClientHello { + return "" + } + + // Handshake length (3 bytes, big-endian). + handshakeLen := int(payload[1])<<16 | int(payload[2])<<8 | int(payload[3]) + if handshakeLen > len(payload)-4 { + return "" + } + + return parseSNIFromClientHello(payload[4 : 4+handshakeLen]) +} + +// parseSNIFromClientHello walks the ClientHello message fields to reach +// the extensions block and extract the server_name extension value. +func parseSNIFromClientHello(msg []byte) string { + // ClientHello layout: + // ProtocolVersion(2) + Random(32) = 34 bytes minimum before session_id + if len(msg) < 34 { + return "" + } + + pos := 34 + + // Session ID (variable, 1 byte length prefix). + if pos >= len(msg) { + return "" + } + sessionIDLen := int(msg[pos]) + pos++ + pos += sessionIDLen + + // Cipher suites (variable, 2 byte length prefix). + if pos+2 > len(msg) { + return "" + } + cipherSuitesLen := int(binary.BigEndian.Uint16(msg[pos : pos+2])) + pos += 2 + cipherSuitesLen + + // Compression methods (variable, 1 byte length prefix). + if pos >= len(msg) { + return "" + } + compMethodsLen := int(msg[pos]) + pos++ + pos += compMethodsLen + + // Extensions (variable, 2 byte length prefix). + if pos+2 > len(msg) { + return "" + } + extensionsLen := int(binary.BigEndian.Uint16(msg[pos : pos+2])) + pos += 2 + + extensionsEnd := pos + extensionsLen + if extensionsEnd > len(msg) { + return "" + } + + return findSNIExtension(msg[pos:extensionsEnd]) +} + +// findSNIExtension iterates over TLS extensions and returns the host +// name from the server_name extension, if present. +func findSNIExtension(extensions []byte) string { + pos := 0 + for pos+4 <= len(extensions) { + extType := binary.BigEndian.Uint16(extensions[pos : pos+2]) + extLen := int(binary.BigEndian.Uint16(extensions[pos+2 : pos+4])) + pos += 4 + + if pos+extLen > len(extensions) { + return "" + } + + if extType == extensionServerName { + return parseSNIExtensionData(extensions[pos : pos+extLen]) + } + pos += extLen + } + return "" +} + +// parseSNIExtensionData parses the ServerNameList structure inside an +// SNI extension to extract the host name. +func parseSNIExtensionData(data []byte) string { + if len(data) < 2 { + return "" + } + listLen := int(binary.BigEndian.Uint16(data[0:2])) + if listLen > len(data)-2 { + return "" + } + + list := data[2 : 2+listLen] + pos := 0 + for pos+3 <= len(list) { + nameType := list[pos] + nameLen := int(binary.BigEndian.Uint16(list[pos+1 : pos+3])) + pos += 3 + + if pos+nameLen > len(list) { + return "" + } + + if nameType == sniHostNameType { + name := list[pos : pos+nameLen] + if nameLen > maxSNILen || bytes.ContainsRune(name, 0) { + return "" + } + return string(name) + } + pos += nameLen + } + return "" +} diff --git a/proxy/internal/tcp/snipeek_test.go b/proxy/internal/tcp/snipeek_test.go new file mode 100644 index 000000000..9afe6261d --- /dev/null +++ b/proxy/internal/tcp/snipeek_test.go @@ -0,0 +1,251 @@ +package tcp + +import ( + "crypto/tls" + "io" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPeekClientHello_ValidSNI(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + const expectedSNI = "example.com" + trailingData := []byte("trailing data after handshake") + + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: expectedSNI, + InsecureSkipVerify: true, //nolint:gosec + }) + // The Handshake will send the ClientHello. It will fail because + // our server side isn't doing a real TLS handshake, but that's + // fine: we only need the ClientHello to be sent. + _ = tlsConn.Handshake() + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello") + assert.NotNil(t, wrapped, "wrapped connection should not be nil") + + // Verify the wrapped connection replays the peeked bytes. + // Read the first 5 bytes (TLS record header) to confirm replay. + buf := make([]byte, 5) + n, err := wrapped.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, byte(contentTypeHandshake), buf[0], "first byte should be TLS handshake content type") + + // Write trailing data from the client side and verify it arrives + // through the wrapped connection after the peeked bytes. + go func() { + _, _ = clientConn.Write(trailingData) + }() + + // Drain the rest of the peeked ClientHello first. + peekedRest := make([]byte, 16384) + _, _ = wrapped.Read(peekedRest) + + got := make([]byte, len(trailingData)) + n, err = io.ReadFull(wrapped, got) + require.NoError(t, err) + assert.Equal(t, trailingData, got[:n]) +} + +func TestPeekClientHello_MultipleSNIs(t *testing.T) { + tests := []struct { + name string + serverName string + expectedSNI string + }{ + {"simple domain", "example.com", "example.com"}, + {"subdomain", "sub.example.com", "sub.example.com"}, + {"deep subdomain", "a.b.c.example.com", "a.b.c.example.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: tt.serverName, + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Equal(t, tt.expectedSNI, sni) + assert.NotNil(t, wrapped) + }) + } +} + +func TestPeekClientHello_NonTLSData(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + // Send plain HTTP data (not TLS). + httpData := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + go func() { + _, _ = clientConn.Write(httpData) + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Empty(t, sni, "should return empty SNI for non-TLS data") + assert.NotNil(t, wrapped) + + // Verify the wrapped connection still provides the original data. + buf := make([]byte, len(httpData)) + n, err := io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, httpData, buf[:n], "wrapped connection should replay original data") +} + +func TestPeekClientHello_TruncatedHeader(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + // Write only 3 bytes then close, fewer than the 5-byte TLS header. + go func() { + _, _ = clientConn.Write([]byte{0x16, 0x03, 0x01}) + clientConn.Close() + }() + + _, _, err := PeekClientHello(serverConn) + assert.Error(t, err, "should error on truncated header") +} + +func TestPeekClientHello_TruncatedPayload(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + // Write a valid TLS header claiming 100 bytes, but only send 10. + go func() { + header := []byte{0x16, 0x03, 0x01, 0x00, 0x64} // 100 bytes claimed + _, _ = clientConn.Write(header) + _, _ = clientConn.Write(make([]byte, 10)) + clientConn.Close() + }() + + _, _, err := PeekClientHello(serverConn) + assert.Error(t, err, "should error on truncated payload") +} + +func TestPeekClientHello_ZeroLengthRecord(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + // TLS handshake header with zero-length payload. + go func() { + _, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00}) + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Empty(t, sni) + assert.NotNil(t, wrapped) +} + +func TestExtractSNI_InvalidPayload(t *testing.T) { + tests := []struct { + name string + payload []byte + }{ + {"nil", nil}, + {"empty", []byte{}}, + {"too short", []byte{0x01, 0x00}}, + {"wrong handshake type", []byte{0x02, 0x00, 0x00, 0x05, 0x03, 0x03, 0x00, 0x00, 0x00}}, + {"truncated client hello", []byte{0x01, 0x00, 0x00, 0x20}}, // claims 32 bytes but has none + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Empty(t, extractSNI(tt.payload)) + }) + } +} + +func TestPeekedConn_CloseWrite(t *testing.T) { + t.Run("delegates to underlying TCPConn", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + accepted := make(chan net.Conn, 1) + go func() { + c, err := ln.Accept() + if err == nil { + accepted <- c + } + }() + + client, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer client.Close() + + server := <-accepted + defer server.Close() + + wrapped := newPeekedConn(server, []byte("peeked")) + + // CloseWrite should succeed on a real TCP connection. + err = wrapped.CloseWrite() + assert.NoError(t, err) + + // The client should see EOF on reads after CloseWrite. + buf := make([]byte, 1) + _, err = client.Read(buf) + assert.Equal(t, io.EOF, err, "client should see EOF after half-close") + }) + + t.Run("no-op on non-halfcloser", func(t *testing.T) { + // net.Pipe does not implement CloseWrite. + _, server := net.Pipe() + defer server.Close() + + wrapped := newPeekedConn(server, []byte("peeked")) + err := wrapped.CloseWrite() + assert.NoError(t, err, "should be no-op on non-halfcloser") + }) +} + +func TestPeekedConn_ReplayAndPassthrough(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + peeked := []byte("peeked-data") + subsequent := []byte("subsequent-data") + + wrapped := newPeekedConn(serverConn, peeked) + + go func() { + _, _ = clientConn.Write(subsequent) + }() + + // Read should return peeked data first. + buf := make([]byte, len(peeked)) + n, err := io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, peeked, buf[:n]) + + // Then subsequent data from the real connection. + buf = make([]byte, len(subsequent)) + n, err = io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, subsequent, buf[:n]) +} diff --git a/proxy/internal/types/types.go b/proxy/internal/types/types.go index 41acfef40..bf3731803 100644 --- a/proxy/internal/types/types.go +++ b/proxy/internal/types/types.go @@ -1,5 +1,56 @@ // Package types defines common types used across the proxy package. package types +import ( + "context" + "net" + "time" +) + // AccountID represents a unique identifier for a NetBird account. type AccountID string + +// ServiceID represents a unique identifier for a proxy service. +type ServiceID string + +// ServiceMode describes how a reverse proxy service is exposed. +type ServiceMode string + +const ( + ServiceModeHTTP ServiceMode = "http" + ServiceModeTCP ServiceMode = "tcp" + ServiceModeUDP ServiceMode = "udp" + ServiceModeTLS ServiceMode = "tls" +) + +// IsL4 returns true for TCP, UDP, and TLS modes. +func (m ServiceMode) IsL4() bool { + return m == ServiceModeTCP || m == ServiceModeUDP || m == ServiceModeTLS +} + +// RelayDirection indicates the direction of a relayed packet. +type RelayDirection string + +const ( + RelayDirectionClientToBackend RelayDirection = "client_to_backend" + RelayDirectionBackendToClient RelayDirection = "backend_to_client" +) + +// DialContextFunc dials a backend through the WireGuard tunnel. +type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// dialTimeoutKey is the context key for a per-request dial timeout. +type dialTimeoutKey struct{} + +// WithDialTimeout returns a context carrying a dial timeout that +// DialContext wrappers can use to scope the timeout to just the +// connection establishment phase. +func WithDialTimeout(ctx context.Context, d time.Duration) context.Context { + return context.WithValue(ctx, dialTimeoutKey{}, d) +} + +// DialTimeoutFromContext returns the dial timeout from the context, if set. +func DialTimeoutFromContext(ctx context.Context) (time.Duration, bool) { + d, ok := ctx.Value(dialTimeoutKey{}).(time.Duration) + return d, ok && d > 0 +} diff --git a/proxy/internal/types/types_test.go b/proxy/internal/types/types_test.go new file mode 100644 index 000000000..dd9738442 --- /dev/null +++ b/proxy/internal/types/types_test.go @@ -0,0 +1,54 @@ +package types + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestServiceMode_IsL4(t *testing.T) { + tests := []struct { + mode ServiceMode + want bool + }{ + {ServiceModeHTTP, false}, + {ServiceModeTCP, true}, + {ServiceModeUDP, true}, + {ServiceModeTLS, true}, + {ServiceMode("unknown"), false}, + } + + for _, tt := range tests { + t.Run(string(tt.mode), func(t *testing.T) { + assert.Equal(t, tt.want, tt.mode.IsL4()) + }) + } +} + +func TestDialTimeoutContext(t *testing.T) { + t.Run("round trip", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), 5*time.Second) + d, ok := DialTimeoutFromContext(ctx) + assert.True(t, ok) + assert.Equal(t, 5*time.Second, d) + }) + + t.Run("missing", func(t *testing.T) { + _, ok := DialTimeoutFromContext(context.Background()) + assert.False(t, ok) + }) + + t.Run("zero returns false", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), 0) + _, ok := DialTimeoutFromContext(ctx) + assert.False(t, ok, "zero duration should return ok=false") + }) + + t.Run("negative returns false", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), -1*time.Second) + _, ok := DialTimeoutFromContext(ctx) + assert.False(t, ok, "negative duration should return ok=false") + }) +} diff --git a/proxy/internal/udp/relay.go b/proxy/internal/udp/relay.go new file mode 100644 index 000000000..f2f58e858 --- /dev/null +++ b/proxy/internal/udp/relay.go @@ -0,0 +1,496 @@ +package udp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/time/rate" + + "github.com/netbirdio/netbird/proxy/internal/accesslog" + "github.com/netbirdio/netbird/proxy/internal/netutil" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +const ( + // DefaultSessionTTL is the default idle timeout for UDP sessions before cleanup. + DefaultSessionTTL = 30 * time.Second + // cleanupInterval is how often the cleaner goroutine runs. + cleanupInterval = time.Minute + // maxPacketSize is the maximum UDP packet size we'll handle. + maxPacketSize = 65535 + // DefaultMaxSessions is the default cap on concurrent UDP sessions per relay. + DefaultMaxSessions = 1024 + // sessionCreateRate limits new session creation per second. + sessionCreateRate = 50 + // sessionCreateBurst is the burst allowance for session creation. + sessionCreateBurst = 100 + // defaultDialTimeout is the fallback dial timeout for backend connections. + defaultDialTimeout = 30 * time.Second +) + +// l4Logger sends layer-4 access log entries to the management server. +type l4Logger interface { + LogL4(entry accesslog.L4Entry) +} + +// SessionObserver receives callbacks for UDP session lifecycle events. +// All methods must be safe for concurrent use. +type SessionObserver interface { + UDPSessionStarted(accountID types.AccountID) + UDPSessionEnded(accountID types.AccountID) + UDPSessionDialError(accountID types.AccountID) + UDPSessionRejected(accountID types.AccountID) + UDPPacketRelayed(direction types.RelayDirection, bytes int) +} + +// clientAddr is a typed key for UDP session lookups. +type clientAddr string + +// Relay listens for incoming UDP packets on a dedicated port and +// maintains per-client sessions that relay packets to a backend +// through the WireGuard tunnel. +type Relay struct { + logger *log.Entry + listener net.PacketConn + target string + domain string + accountID types.AccountID + serviceID types.ServiceID + dialFunc types.DialContextFunc + dialTimeout time.Duration + sessionTTL time.Duration + maxSessions int + + mu sync.RWMutex + sessions map[clientAddr]*session + + bufPool sync.Pool + sessLimiter *rate.Limiter + sessWg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + observer SessionObserver + accessLog l4Logger +} + +type session struct { + backend net.Conn + addr net.Addr + createdAt time.Time + // lastSeen stores the last activity timestamp as unix nanoseconds. + lastSeen atomic.Int64 + cancel context.CancelFunc + // bytesIn tracks total bytes received from the client. + bytesIn atomic.Int64 + // bytesOut tracks total bytes sent back to the client. + bytesOut atomic.Int64 +} + +func (s *session) updateLastSeen() { + s.lastSeen.Store(time.Now().UnixNano()) +} + +func (s *session) idleDuration() time.Duration { + return time.Since(time.Unix(0, s.lastSeen.Load())) +} + +// RelayConfig holds the configuration for a UDP relay. +type RelayConfig struct { + Logger *log.Entry + Listener net.PacketConn + Target string + Domain string + AccountID types.AccountID + ServiceID types.ServiceID + DialFunc types.DialContextFunc + DialTimeout time.Duration + SessionTTL time.Duration + MaxSessions int + AccessLog l4Logger +} + +// New creates a UDP relay for the given listener and backend target. +// MaxSessions caps the number of concurrent sessions; use 0 for DefaultMaxSessions. +// DialTimeout controls how long to wait for backend connections; use 0 for default. +// SessionTTL is the idle timeout before a session is reaped; use 0 for DefaultSessionTTL. +func New(parentCtx context.Context, cfg RelayConfig) *Relay { + maxSessions := cfg.MaxSessions + dialTimeout := cfg.DialTimeout + sessionTTL := cfg.SessionTTL + if maxSessions <= 0 { + maxSessions = DefaultMaxSessions + } + if dialTimeout <= 0 { + dialTimeout = defaultDialTimeout + } + if sessionTTL <= 0 { + sessionTTL = DefaultSessionTTL + } + ctx, cancel := context.WithCancel(parentCtx) + return &Relay{ + logger: cfg.Logger, + listener: cfg.Listener, + target: cfg.Target, + domain: cfg.Domain, + accountID: cfg.AccountID, + serviceID: cfg.ServiceID, + accessLog: cfg.AccessLog, + dialFunc: cfg.DialFunc, + dialTimeout: dialTimeout, + sessionTTL: sessionTTL, + maxSessions: maxSessions, + sessions: make(map[clientAddr]*session), + bufPool: sync.Pool{ + New: func() any { + buf := make([]byte, maxPacketSize) + return &buf + }, + }, + sessLimiter: rate.NewLimiter(sessionCreateRate, sessionCreateBurst), + ctx: ctx, + cancel: cancel, + } +} + +// ServiceID returns the service ID associated with this relay. +func (r *Relay) ServiceID() types.ServiceID { + return r.serviceID +} + +// SetObserver sets the session lifecycle observer. Must be called before Serve. +func (r *Relay) SetObserver(obs SessionObserver) { + r.observer = obs +} + +// Serve starts the relay loop. It blocks until the context is canceled +// or the listener is closed. +func (r *Relay) Serve() { + go r.cleanupLoop() + + for { + bufp := r.bufPool.Get().(*[]byte) + buf := *bufp + + n, addr, err := r.listener.ReadFrom(buf) + if err != nil { + r.bufPool.Put(bufp) + if r.ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + return + } + r.logger.Debugf("UDP read: %v", err) + continue + } + + data := buf[:n] + sess, err := r.getOrCreateSession(addr) + if err != nil { + r.bufPool.Put(bufp) + r.logger.Debugf("create UDP session for %s: %v", addr, err) + continue + } + + sess.updateLastSeen() + + nw, err := sess.backend.Write(data) + if err != nil { + r.bufPool.Put(bufp) + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP write to backend for %s: %v", addr, err) + } + r.removeSession(sess) + continue + } + sess.bytesIn.Add(int64(nw)) + + if r.observer != nil { + r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw) + } + r.bufPool.Put(bufp) + } +} + +// getOrCreateSession returns an existing session or creates a new one. +func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) { + key := clientAddr(addr.String()) + + r.mu.RLock() + sess, ok := r.sessions[key] + r.mu.RUnlock() + if ok && sess != nil { + return sess, nil + } + + // Check before taking the write lock: if the relay is shutting down, + // don't create new sessions. This prevents orphaned goroutines when + // Serve() processes a packet that was already read before Close(). + if r.ctx.Err() != nil { + return nil, r.ctx.Err() + } + + r.mu.Lock() + + if sess, ok = r.sessions[key]; ok && sess != nil { + r.mu.Unlock() + return sess, nil + } + if ok { + // Another goroutine is dialing for this key, skip. + r.mu.Unlock() + return nil, fmt.Errorf("session dial in progress for %s", key) + } + + if len(r.sessions) >= r.maxSessions { + r.mu.Unlock() + if r.observer != nil { + r.observer.UDPSessionRejected(r.accountID) + } + return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions) + } + + if !r.sessLimiter.Allow() { + r.mu.Unlock() + if r.observer != nil { + r.observer.UDPSessionRejected(r.accountID) + } + return nil, fmt.Errorf("session creation rate limited") + } + + // Reserve the slot with a nil session so concurrent callers for the same + // key see it exists and wait. Release the lock before dialing. + r.sessions[key] = nil + r.mu.Unlock() + + dialCtx, dialCancel := context.WithTimeout(r.ctx, r.dialTimeout) + backend, err := r.dialFunc(dialCtx, "udp", r.target) + dialCancel() + if err != nil { + r.mu.Lock() + delete(r.sessions, key) + r.mu.Unlock() + if r.observer != nil { + r.observer.UDPSessionDialError(r.accountID) + } + return nil, fmt.Errorf("dial backend %s: %w", r.target, err) + } + + sessCtx, sessCancel := context.WithCancel(r.ctx) + sess = &session{ + backend: backend, + addr: addr, + createdAt: time.Now(), + cancel: sessCancel, + } + sess.updateLastSeen() + + r.mu.Lock() + r.sessions[key] = sess + r.mu.Unlock() + + if r.observer != nil { + r.observer.UDPSessionStarted(r.accountID) + } + + r.sessWg.Go(func() { + r.relayBackendToClient(sessCtx, sess) + }) + + r.logger.Debugf("UDP session created for %s", addr) + return sess, nil +} + +// relayBackendToClient reads packets from the backend and writes them +// back to the client through the public-facing listener. +func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) { + bufp := r.bufPool.Get().(*[]byte) + defer r.bufPool.Put(bufp) + defer r.removeSession(sess) + + for ctx.Err() == nil { + data, ok := r.readBackendPacket(sess, *bufp) + if !ok { + return + } + if data == nil { + continue + } + + sess.updateLastSeen() + + nw, err := r.listener.WriteTo(data, sess.addr) + if err != nil { + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP write to client %s: %v", sess.addr, err) + } + return + } + sess.bytesOut.Add(int64(nw)) + + if r.observer != nil { + r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw) + } + } +} + +// readBackendPacket reads one packet from the backend with an idle deadline. +// Returns (data, true) on success, (nil, true) on idle timeout that should +// retry, or (nil, false) when the session should be torn down. +func (r *Relay) readBackendPacket(sess *session, buf []byte) ([]byte, bool) { + if err := sess.backend.SetReadDeadline(time.Now().Add(r.sessionTTL)); err != nil { + r.logger.Debugf("set backend read deadline for %s: %v", sess.addr, err) + return nil, false + } + + n, err := sess.backend.Read(buf) + if err != nil { + if netutil.IsTimeout(err) { + if sess.idleDuration() > r.sessionTTL { + return nil, false + } + return nil, true + } + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP read from backend for %s: %v", sess.addr, err) + } + return nil, false + } + + return buf[:n], true +} + +// cleanupLoop periodically removes idle sessions. +func (r *Relay) cleanupLoop() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-r.ctx.Done(): + return + case <-ticker.C: + r.cleanupIdleSessions() + } + } +} + +// cleanupIdleSessions closes sessions that have been idle for too long. +func (r *Relay) cleanupIdleSessions() { + var expired []*session + + r.mu.Lock() + for key, sess := range r.sessions { + if sess == nil { + continue + } + idle := sess.idleDuration() + if idle > r.sessionTTL { + r.logger.Debugf("UDP session %s idle for %s, closing (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, idle, sess.bytesIn.Load(), sess.bytesOut.Load()) + delete(r.sessions, key) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close idle session %s backend: %v", sess.addr, err) + } + expired = append(expired, sess) + } + } + r.mu.Unlock() + + for _, sess := range expired { + if r.observer != nil { + r.observer.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } +} + +// removeSession removes a session from the map if it still matches the +// given pointer. This is safe to call concurrently with cleanupIdleSessions +// because the identity check prevents double-close when both paths race. +func (r *Relay) removeSession(sess *session) { + r.mu.Lock() + key := clientAddr(sess.addr.String()) + removed := r.sessions[key] == sess + if removed { + delete(r.sessions, key) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close session %s backend: %v", sess.addr, err) + } + } + r.mu.Unlock() + + if removed { + r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load()) + if r.observer != nil { + r.observer.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } +} + +// logSessionEnd sends an access log entry for a completed UDP session. +func (r *Relay) logSessionEnd(sess *session) { + if r.accessLog == nil { + return + } + + var sourceIP netip.Addr + if ap, err := netip.ParseAddrPort(sess.addr.String()); err == nil { + sourceIP = ap.Addr().Unmap() + } + + r.accessLog.LogL4(accesslog.L4Entry{ + AccountID: r.accountID, + ServiceID: r.serviceID, + Protocol: accesslog.ProtocolUDP, + Host: r.domain, + SourceIP: sourceIP, + DurationMs: time.Unix(0, sess.lastSeen.Load()).Sub(sess.createdAt).Milliseconds(), + BytesUpload: sess.bytesIn.Load(), + BytesDownload: sess.bytesOut.Load(), + }) +} + +// Close stops the relay, waits for all session goroutines to exit, +// and cleans up remaining sessions. +func (r *Relay) Close() { + r.cancel() + if err := r.listener.Close(); err != nil { + r.logger.Debugf("close UDP listener: %v", err) + } + + var closedSessions []*session + r.mu.Lock() + for key, sess := range r.sessions { + if sess == nil { + delete(r.sessions, key) + continue + } + r.logger.Debugf("UDP session %s closed (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load()) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close session %s backend: %v", sess.addr, err) + } + delete(r.sessions, key) + closedSessions = append(closedSessions, sess) + } + r.mu.Unlock() + + for _, sess := range closedSessions { + if r.observer != nil { + r.observer.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } + + r.sessWg.Wait() +} diff --git a/proxy/internal/udp/relay_test.go b/proxy/internal/udp/relay_test.go new file mode 100644 index 000000000..a1e91b290 --- /dev/null +++ b/proxy/internal/udp/relay_test.go @@ -0,0 +1,493 @@ +package udp + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestRelay_BasicPacketExchange(t *testing.T) { + // Set up a UDP backend that echoes packets. + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + // Set up the relay's public-facing listener. + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + backendAddr := backend.LocalAddr().String() + + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backendAddr, DialFunc: dialFunc}) + go relay.Serve() + defer relay.Close() + + // Create a client and send a packet to the relay. + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err) + defer client.Close() + + testData := []byte("hello UDP relay") + _, err = client.Write(testData) + require.NoError(t, err) + + // Read the echoed response. + if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := client.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed packet") +} + +func TestRelay_MultipleClients(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay.Serve() + defer relay.Close() + + // Two clients, each should get their own session. + for i, msg := range []string{"client-1", "client-2"} { + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err, "client %d", i) + defer client.Close() + + _, err = client.Write([]byte(msg)) + require.NoError(t, err) + + if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := client.Read(buf) + require.NoError(t, err, "client %d read", i) + assert.Equal(t, msg, string(buf[:n]), "client %d should get own echo", i) + } + + // Verify two sessions were created. + relay.mu.RLock() + sessionCount := len(relay.sessions) + relay.mu.RUnlock() + assert.Equal(t, 2, sessionCount, "should have two sessions") +} + +func TestRelay_Close(t *testing.T) { + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: "127.0.0.1:9999", DialFunc: dialFunc}) + + done := make(chan struct{}) + go func() { + relay.Serve() + close(done) + }() + + relay.Close() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Serve did not return after Close") + } +} + +func TestRelay_SessionCleanup(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay.Serve() + defer relay.Close() + + // Create a session. + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err) + _, err = client.Write([]byte("hello")) + require.NoError(t, err) + + if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + _, err = client.Read(buf) + require.NoError(t, err) + client.Close() + + // Verify session exists. + relay.mu.RLock() + assert.Equal(t, 1, len(relay.sessions)) + relay.mu.RUnlock() + + // Make session appear idle by setting lastSeen to the past. + relay.mu.Lock() + for _, sess := range relay.sessions { + sess.lastSeen.Store(time.Now().Add(-2 * DefaultSessionTTL).UnixNano()) + } + relay.mu.Unlock() + + // Trigger cleanup manually. + relay.cleanupIdleSessions() + + relay.mu.RLock() + assert.Equal(t, 0, len(relay.sessions), "idle sessions should be cleaned up") + relay.mu.RUnlock() +} + +// TestRelay_CloseAndRecreate verifies that closing a relay and creating a new +// one on the same port works cleanly (simulates port mapping modify cycle). +func TestRelay_CloseAndRecreate(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + // First relay. + ln1, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + relay1 := New(ctx, RelayConfig{Logger: logger, Listener: ln1, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay1.Serve() + + client1, err := net.Dial("udp", ln1.LocalAddr().String()) + require.NoError(t, err) + _, err = client1.Write([]byte("relay1")) + require.NoError(t, err) + require.NoError(t, client1.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1024) + n, err := client1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "relay1", string(buf[:n])) + client1.Close() + + // Close first relay. + relay1.Close() + + // Second relay on same port. + port := ln1.LocalAddr().(*net.UDPAddr).Port + ln2, err := net.ListenPacket("udp", fmt.Sprintf("127.0.0.1:%d", port)) + require.NoError(t, err) + + relay2 := New(ctx, RelayConfig{Logger: logger, Listener: ln2, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay2.Serve() + defer relay2.Close() + + client2, err := net.Dial("udp", ln2.LocalAddr().String()) + require.NoError(t, err) + defer client2.Close() + _, err = client2.Write([]byte("relay2")) + require.NoError(t, err) + require.NoError(t, client2.SetReadDeadline(time.Now().Add(2*time.Second))) + n, err = client2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "relay2", string(buf[:n]), "second relay should work on same port") +} + +func TestRelay_SessionLimit(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + // Create a relay with a max of 2 sessions. + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc, MaxSessions: 2}) + go relay.Serve() + defer relay.Close() + + // Create 2 clients to fill up the session limit. + for i := range 2 { + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err, "client %d", i) + defer client.Close() + + _, err = client.Write([]byte("hello")) + require.NoError(t, err) + + require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1024) + _, err = client.Read(buf) + require.NoError(t, err, "client %d should get response", i) + } + + relay.mu.RLock() + assert.Equal(t, 2, len(relay.sessions), "should have exactly 2 sessions") + relay.mu.RUnlock() + + // Third client should get its packet dropped (session creation fails). + client3, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err) + defer client3.Close() + + _, err = client3.Write([]byte("should be dropped")) + require.NoError(t, err) + + require.NoError(t, client3.SetReadDeadline(time.Now().Add(500*time.Millisecond))) + buf := make([]byte, 1024) + _, err = client3.Read(buf) + assert.Error(t, err, "third client should time out because session was rejected") + + relay.mu.RLock() + assert.Equal(t, 2, len(relay.sessions), "session count should not exceed limit") + relay.mu.RUnlock() +} + +// testObserver records UDP session lifecycle events for test assertions. +type testObserver struct { + mu sync.Mutex + started int + ended int + rejected int + dialErr int + packets int + bytes int +} + +func (o *testObserver) UDPSessionStarted(types.AccountID) { o.mu.Lock(); o.started++; o.mu.Unlock() } +func (o *testObserver) UDPSessionEnded(types.AccountID) { o.mu.Lock(); o.ended++; o.mu.Unlock() } +func (o *testObserver) UDPSessionDialError(types.AccountID) { o.mu.Lock(); o.dialErr++; o.mu.Unlock() } +func (o *testObserver) UDPSessionRejected(types.AccountID) { o.mu.Lock(); o.rejected++; o.mu.Unlock() } +func (o *testObserver) UDPPacketRelayed(_ types.RelayDirection, b int) { + o.mu.Lock() + o.packets++ + o.bytes += b + o.mu.Unlock() +} + +func TestRelay_CloseFiresObserverEnded(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + obs := &testObserver{} + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc}) + relay.SetObserver(obs) + go relay.Serve() + + // Create two sessions. + for i := range 2 { + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err, "client %d", i) + + _, err = client.Write([]byte("hello")) + require.NoError(t, err) + + require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1024) + _, err = client.Read(buf) + require.NoError(t, err) + client.Close() + } + + obs.mu.Lock() + assert.Equal(t, 2, obs.started, "should have 2 started events") + obs.mu.Unlock() + + // Close should fire UDPSessionEnded for all remaining sessions. + relay.Close() + + obs.mu.Lock() + assert.Equal(t, 2, obs.ended, "Close should fire UDPSessionEnded for each session") + obs.mu.Unlock() +} + +func TestRelay_SessionRateLimit(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + obs := &testObserver{} + // High max sessions (1000) but the relay uses a rate limiter internally + // (default: 50/s burst 100). We exhaust the burst by creating sessions + // rapidly, then verify that subsequent creates are rejected. + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc, MaxSessions: 1000}) + relay.SetObserver(obs) + go relay.Serve() + defer relay.Close() + + // Exhaust the burst by calling getOrCreateSession directly with + // synthetic addresses. This is faster than real UDP round-trips. + for i := range sessionCreateBurst + 20 { + addr := &net.UDPAddr{IP: net.IPv4(10, 0, byte(i/256), byte(i%256)), Port: 10000 + i} + _, _ = relay.getOrCreateSession(addr) + } + + obs.mu.Lock() + rejected := obs.rejected + obs.mu.Unlock() + + assert.Greater(t, rejected, 0, "some sessions should be rate-limited") +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 6a0ecce30..ebecfc6f6 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -243,6 +243,10 @@ func (c *testProxyController) GetProxiesForCluster(_ string) []string { return nil } +func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { + return nil +} + // storeBackedServiceManager reads directly from the real store. type storeBackedServiceManager struct { store store.Store @@ -505,15 +509,15 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T nil, "", 0, - mapping.GetAccountId(), - mapping.GetId(), + proxytypes.AccountID(mapping.GetAccountId()), + proxytypes.ServiceID(mapping.GetId()), ) require.NoError(t, err) // Apply to real proxy (idempotent) proxyHandler.AddMapping(proxy.Mapping{ Host: mapping.GetDomain(), - ID: mapping.GetId(), + ID: proxytypes.ServiceID(mapping.GetId()), AccountID: proxytypes.AccountID(mapping.GetAccountId()), }) } diff --git a/proxy/server.go b/proxy/server.go index 62e8368e6..649d49c9a 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -30,6 +30,7 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/exporters/prometheus" "go.opentelemetry.io/otel/sdk/metric" + "golang.org/x/exp/maps" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -46,15 +47,26 @@ import ( "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/k8s" proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics" + "github.com/netbirdio/netbird/proxy/internal/netutil" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/roundtrip" + nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp" "github.com/netbirdio/netbird/proxy/internal/types" + udprelay "github.com/netbirdio/netbird/proxy/internal/udp" "github.com/netbirdio/netbird/proxy/web" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util/embeddedroots" ) + +// portRouter bundles a per-port Router with its listener and cancel func. +type portRouter struct { + router *nbtcp.Router + listener net.Listener + cancel context.CancelFunc +} + type Server struct { mgmtClient proto.ProxyServiceClient proxy *proxy.ReverseProxy @@ -67,12 +79,27 @@ type Server struct { healthServer *health.Server healthChecker *health.Checker meter *proxymetrics.Metrics + accessLog *accesslog.Logger + mainRouter *nbtcp.Router + mainPort uint16 + udpMu sync.Mutex + udpRelays map[types.ServiceID]*udprelay.Relay + udpRelayWg sync.WaitGroup + portMu sync.RWMutex + portRouters map[uint16]*portRouter + svcPorts map[types.ServiceID][]uint16 + lastMappings map[types.ServiceID]*proto.ProxyMapping + portRouterWg sync.WaitGroup // hijackTracker tracks hijacked connections (e.g. WebSocket upgrades) // so they can be closed during graceful shutdown, since http.Server.Shutdown // does not handle them. hijackTracker conntrack.HijackTracker + // routerReady is closed once mainRouter is fully initialized. + // The mapping worker waits on this before processing updates. + routerReady chan struct{} + // Mostly used for debugging on management. startTime time.Time @@ -118,28 +145,36 @@ type Server struct { // When set, forwarding headers from these sources are preserved and // appended to instead of being stripped. TrustedProxies []netip.Prefix - // WireguardPort is the port for the WireGuard interface. Use 0 for a - // random OS-assigned port. A fixed port only works with single-account - // deployments; multiple accounts will fail to bind the same port. - WireguardPort int + // WireguardPort is the port for the NetBird tunnel interface. Use 0 + // for a random OS-assigned port. A fixed port only works with + // single-account deployments; multiple accounts will fail to bind + // the same port. + WireguardPort uint16 // ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners. // When enabled, the real client IP is extracted from the PROXY header // sent by upstream L4 proxies that support PROXY protocol. ProxyProtocol bool // PreSharedKey used for tunnel between proxy and peers (set globally not per account) PreSharedKey string + // SupportsCustomPorts indicates whether the proxy can bind arbitrary + // ports for TCP/UDP/TLS services. + SupportsCustomPorts bool + // DefaultDialTimeout is the default timeout for establishing backend + // connections when no per-service timeout is configured. Zero means + // each transport uses its own hardcoded default (typically 30s). + DefaultDialTimeout time.Duration } -// NotifyStatus sends a status update to management about tunnel connectivity -func (s *Server) NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error { +// NotifyStatus sends a status update to management about tunnel connectivity. +func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error { status := proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED if connected { status = proto.ProxyStatus_PROXY_STATUS_ACTIVE } _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ - ServiceId: serviceID, - AccountId: accountID, + ServiceId: string(serviceID), + AccountId: string(accountID), Status: status, CertificateIssued: false, }) @@ -147,10 +182,10 @@ func (s *Server) NotifyStatus(ctx context.Context, accountID, serviceID, domain } // NotifyCertificateIssued sends a notification to management that a certificate was issued -func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error { +func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error { _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ - ServiceId: serviceID, - AccountId: accountID, + ServiceId: string(serviceID), + AccountId: string(accountID), Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE, CertificateIssued: true, }) @@ -159,6 +194,11 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.initDefaults() + s.routerReady = make(chan struct{}) + s.udpRelays = make(map[types.ServiceID]*udprelay.Relay) + s.portRouters = make(map[uint16]*portRouter) + s.svcPorts = make(map[types.ServiceID][]uint16) + s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping) exporter, err := prometheus.New() if err != nil { @@ -184,7 +224,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { } }() s.mgmtClient = proto.NewProxyServiceClient(mgmtConn) - go s.newManagementMappingWorker(ctx, s.mgmtClient) + runCtx, runCancel := context.WithCancel(ctx) + defer runCancel() + go s.newManagementMappingWorker(runCtx, s.mgmtClient) // Initialize the netbird client, this is required to build peer connections // to proxy over. @@ -206,7 +248,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient) // Configure Access logs to management server. - accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) + s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) s.healthChecker = health.NewChecker(s.Logger, s.netbird) @@ -220,18 +262,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { handler := http.Handler(s.proxy) handler = s.auth.Protect(handler) handler = web.AssetHandler(handler) - handler = accessLog.Middleware(handler) + handler = s.accessLog.Middleware(handler) handler = s.meter.Middleware(handler) handler = s.hijackTracker.Middleware(handler) - // Start the reverse proxy HTTPS server. - s.https = &http.Server{ - Addr: addr, - Handler: handler, - TLSConfig: tlsConfig, - ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), - } - + // Start a raw TCP listener; the SNI router peeks at ClientHello + // and routes to either the HTTP handler or a TCP relay. lc := net.ListenConfig{} ln, err := lc.Listen(ctx, "tcp", addr) if err != nil { @@ -240,11 +276,34 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { if s.ProxyProtocol { ln = s.wrapProxyProtocol(ln) } + s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid + + // Set up the SNI router for TCP/HTTP multiplexing on the main port. + s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr()) + s.mainRouter.SetObserver(s.meter) + s.mainRouter.SetAccessLogger(s.accessLog) + close(s.routerReady) + + // The HTTP server uses the chanListener fed by the SNI router. + s.https = &http.Server{ + Addr: addr, + Handler: handler, + TLSConfig: tlsConfig, + ReadHeaderTimeout: httpReadHeaderTimeout, + IdleTimeout: httpIdleTimeout, + ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), + } httpsErr := make(chan error, 1) go func() { - s.Logger.Debugf("starting reverse proxy server on %s", addr) - httpsErr <- s.https.ServeTLS(ln, "", "") + s.Logger.Debug("starting HTTPS server on SNI router HTTP channel") + httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "") + }() + + routerErr := make(chan error, 1) + go func() { + s.Logger.Debugf("starting SNI router on %s", addr) + routerErr <- s.mainRouter.Serve(runCtx, ln) }() select { @@ -254,6 +313,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { return fmt.Errorf("https server: %w", err) } return nil + case err := <-routerErr: + s.shutdownServices() + if err != nil { + return fmt.Errorf("SNI router: %w", err) + } + return nil case <-ctx.Done(): s.gracefulShutdown() return nil @@ -381,6 +446,13 @@ const ( // shutdownServiceTimeout is the maximum time to wait for auxiliary // services (health probe, debug endpoint, ACME) to shut down. shutdownServiceTimeout = 5 * time.Second + + // httpReadHeaderTimeout limits how long the server waits to read + // request headers after accepting a connection. Prevents slowloris. + httpReadHeaderTimeout = 10 * time.Second + // httpIdleTimeout limits how long an idle keep-alive connection + // stays open before the server closes it. + httpIdleTimeout = 120 * time.Second ) func (s *Server) dialManagement() (*grpc.ClientConn, error) { @@ -518,6 +590,9 @@ func (s *Server) gracefulShutdown() { s.Logger.Infof("closed %d hijacked connection(s)", n) } + // Drain all router relay connections (main + per-port) in parallel. + s.drainAllRouters(shutdownDrainTimeout) + // Step 5: Stop all remaining background services. s.shutdownServices() s.Logger.Info("graceful shutdown complete") @@ -525,6 +600,34 @@ func (s *Server) gracefulShutdown() { // shutdownServices stops all background services concurrently and waits for // them to finish. +// drainAllRouters drains active relay connections on the main router and +// all per-port routers in parallel, up to the given timeout. +func (s *Server) drainAllRouters(timeout time.Duration) { + var wg sync.WaitGroup + + drain := func(name string, router *nbtcp.Router) { + wg.Add(1) + go func() { + defer wg.Done() + if ok := router.Drain(timeout); !ok { + s.Logger.Warnf("timed out draining %s relay connections", name) + } + }() + } + + if s.mainRouter != nil { + drain("main router", s.mainRouter) + } + + s.portMu.RLock() + for port, pr := range s.portRouters { + drain(fmt.Sprintf("port %d", port), pr.router) + } + s.portMu.RUnlock() + + wg.Wait() +} + func (s *Server) shutdownServices() { var wg sync.WaitGroup @@ -562,9 +665,165 @@ func (s *Server) shutdownServices() { }() } + // Close all UDP relays and wait for their goroutines to exit. + s.udpMu.Lock() + for id, relay := range s.udpRelays { + relay.Close() + delete(s.udpRelays, id) + } + s.udpMu.Unlock() + s.udpRelayWg.Wait() + + // Close all per-port routers. + s.portMu.Lock() + for port, pr := range s.portRouters { + pr.cancel() + if err := pr.listener.Close(); err != nil { + s.Logger.Debugf("close listener on port %d: %v", port, err) + } + delete(s.portRouters, port) + } + maps.Clear(s.svcPorts) + maps.Clear(s.lastMappings) + s.portMu.Unlock() + + // Wait for per-port router serve goroutines to exit. + s.portRouterWg.Wait() + wg.Wait() } +// resolveDialFunc returns a DialContextFunc that dials through the +// NetBird tunnel for the given account. +func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFunc, error) { + client, ok := s.netbird.GetClient(accountID) + if !ok { + return nil, fmt.Errorf("no client for account %s", accountID) + } + return client.DialContext, nil +} + +// notifyError reports a resource error back to management so it can be +// surfaced to the user (e.g. port bind failure, dialer resolution error). +func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) { + s.sendStatusUpdate(ctx, types.AccountID(mapping.GetAccountId()), types.ServiceID(mapping.GetId()), proto.ProxyStatus_PROXY_STATUS_ERROR, err) +} + +// sendStatusUpdate sends a status update for a service to management. +func (s *Server) sendStatusUpdate(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, st proto.ProxyStatus, err error) { + req := &proto.SendStatusUpdateRequest{ + ServiceId: string(serviceID), + AccountId: string(accountID), + Status: st, + } + if err != nil { + msg := err.Error() + req.ErrorMessage = &msg + } + if _, sendErr := s.mgmtClient.SendStatusUpdate(ctx, req); sendErr != nil { + s.Logger.Debugf("failed to send status update for %s: %v", serviceID, sendErr) + } +} + +// routerForPort returns the router that handles the given listen port. If port +// is 0 or matches the main listener port, the main router is returned. +// Otherwise a new per-port router is created and started. +func (s *Server) routerForPort(ctx context.Context, port uint16) (*nbtcp.Router, error) { + if port == 0 || port == s.mainPort { + return s.mainRouter, nil + } + return s.getOrCreatePortRouter(ctx, port) +} + +// routerForPortExisting returns the router for the given port without creating +// one. Returns the main router for port 0 / mainPort, or nil if no per-port +// router exists. +func (s *Server) routerForPortExisting(port uint16) *nbtcp.Router { + if port == 0 || port == s.mainPort { + return s.mainRouter + } + s.portMu.RLock() + pr := s.portRouters[port] + s.portMu.RUnlock() + if pr != nil { + return pr.router + } + return nil +} + +// getOrCreatePortRouter returns an existing per-port router or creates one +// with a new TCP listener and starts serving. +func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp.Router, error) { + s.portMu.Lock() + defer s.portMu.Unlock() + + if pr, ok := s.portRouters[port]; ok { + return pr.router, nil + } + + listenAddr := fmt.Sprintf(":%d", port) + ln, err := net.Listen("tcp", listenAddr) + if err != nil { + return nil, fmt.Errorf("listen TCP on %s: %w", listenAddr, err) + } + if s.ProxyProtocol { + ln = s.wrapProxyProtocol(ln) + } + + router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc) + router.SetObserver(s.meter) + router.SetAccessLogger(s.accessLog) + portCtx, cancel := context.WithCancel(ctx) + + s.portRouters[port] = &portRouter{ + router: router, + listener: ln, + cancel: cancel, + } + + s.portRouterWg.Add(1) + go func() { + defer s.portRouterWg.Done() + if err := router.Serve(portCtx, ln); err != nil { + s.Logger.Debugf("port %d router stopped: %v", port, err) + } + }() + + s.Logger.Debugf("started per-port router on %s", listenAddr) + return router, nil +} + +// cleanupPortIfEmpty tears down a per-port router if it has no remaining +// routes or fallback. The main port is never cleaned up. Active relay +// connections are drained before the listener is closed. +func (s *Server) cleanupPortIfEmpty(port uint16) { + if port == 0 || port == s.mainPort { + return + } + + s.portMu.Lock() + pr, ok := s.portRouters[port] + if !ok || !pr.router.IsEmpty() { + s.portMu.Unlock() + return + } + + // Cancel and close the listener while holding the lock so that + // getOrCreatePortRouter sees the entry is gone before we drain. + pr.cancel() + if err := pr.listener.Close(); err != nil { + s.Logger.Debugf("close listener on port %d: %v", port, err) + } + delete(s.portRouters, port) + s.portMu.Unlock() + + // Drain active relay connections outside the lock. + if ok := pr.router.Drain(nbtcp.DefaultDrainTimeout); !ok { + s.Logger.Warnf("timed out draining relay connections on port %d", port) + } + s.Logger.Debugf("cleaned up empty per-port router on port %d", port) +} + func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) { bo := &backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, @@ -590,6 +849,9 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr Version: s.Version, StartedAt: timestamppb.New(s.startTime), Address: s.ProxyURL, + Capabilities: &proto.ProxyCapabilities{ + SupportsCustomPorts: &s.SupportsCustomPorts, + }, }) if err != nil { return fmt.Errorf("create mapping stream: %w", err) @@ -626,6 +888,12 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr } func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool) error { + select { + case <-s.routerReady: + case <-ctx.Done(): + return ctx.Err() + } + for { // Check for context completion to gracefully shutdown. select { @@ -662,25 +930,28 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap s.Logger.WithFields(log.Fields{ "type": mapping.GetType(), "domain": mapping.GetDomain(), - "path": mapping.GetPath(), + "mode": mapping.GetMode(), + "port": mapping.GetListenPort(), "id": mapping.GetId(), }).Debug("Processing mapping update") switch mapping.GetType() { case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED: if err := s.addMapping(ctx, mapping); err != nil { - // TODO: Retry this? Or maybe notify the management server that this mapping has failed? s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "domain": mapping.GetDomain(), "error": err, }).Error("Error adding new mapping, ignoring this mapping and continuing processing") + s.notifyError(ctx, mapping, err) } case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED: - if err := s.updateMapping(ctx, mapping); err != nil { + if err := s.modifyMapping(ctx, mapping); err != nil { s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "domain": mapping.GetDomain(), - }).Errorf("failed to update mapping: %v", err) + "error": err, + }).Error("failed to modify mapping") + s.notifyError(ctx, mapping, err) } case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED: s.removeMapping(ctx, mapping) @@ -688,30 +959,89 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap } } +// addMapping registers a service mapping and starts the appropriate relay or routes. func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error { - d := domain.Domain(mapping.GetDomain()) accountID := types.AccountID(mapping.GetAccountId()) - serviceID := mapping.GetId() + svcID := types.ServiceID(mapping.GetId()) authToken := mapping.GetAuthToken() - if err := s.netbird.AddPeer(ctx, accountID, d, authToken, serviceID); err != nil { - return fmt.Errorf("create peer for domain %q: %w", d, err) - } - var wildcardHit bool - if s.acme != nil { - wildcardHit = s.acme.AddDomain(d, string(accountID), serviceID) + svcKey := s.serviceKeyForMapping(mapping) + if err := s.netbird.AddPeer(ctx, accountID, svcKey, authToken, svcID); err != nil { + return fmt.Errorf("create peer for service %s: %w", svcID, err) } - // Pass the mapping through to the update function to avoid duplicating the - // setup, currently update is simply a subset of this function, so this - // separation makes sense...to me at least. + if err := s.setupMappingRoutes(ctx, mapping); err != nil { + s.cleanupMappingRoutes(mapping) + if peerErr := s.netbird.RemovePeer(ctx, accountID, svcKey); peerErr != nil { + s.Logger.WithError(peerErr).WithField("service_id", svcID).Warn("failed to remove peer after setup failure") + } + return err + } + s.storeMapping(mapping) + return nil +} + +// modifyMapping updates a service mapping in place without tearing down the +// NetBird peer. It cleans up old routes using the previously stored mapping +// state and re-applies them from the new mapping. +func (s *Server) modifyMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + if old := s.loadMapping(types.ServiceID(mapping.GetId())); old != nil { + s.cleanupMappingRoutes(old) + if mode := types.ServiceMode(old.GetMode()); mode.IsL4() { + s.meter.L4ServiceRemoved(mode) + } + } else { + s.cleanupMappingRoutes(mapping) + } + if err := s.setupMappingRoutes(ctx, mapping); err != nil { + s.cleanupMappingRoutes(mapping) + return err + } + s.storeMapping(mapping) + return nil +} + +// setupMappingRoutes configures the appropriate routes or relays for the given +// service mapping based on its mode. The NetBird peer must already exist. +func (s *Server) setupMappingRoutes(ctx context.Context, mapping *proto.ProxyMapping) error { + switch types.ServiceMode(mapping.GetMode()) { + case types.ServiceModeTCP: + return s.setupTCPMapping(ctx, mapping) + case types.ServiceModeUDP: + return s.setupUDPMapping(ctx, mapping) + case types.ServiceModeTLS: + return s.setupTLSMapping(ctx, mapping) + default: + return s.setupHTTPMapping(ctx, mapping) + } +} + +// setupHTTPMapping configures HTTP reverse proxy, auth, and ACME routes. +func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + d := domain.Domain(mapping.GetDomain()) + accountID := types.AccountID(mapping.GetAccountId()) + svcID := types.ServiceID(mapping.GetId()) + + if len(mapping.GetPath()) == 0 { + return nil + } + + var wildcardHit bool + if s.acme != nil { + wildcardHit = s.acme.AddDomain(d, accountID, svcID) + } + s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ + Type: nbtcp.RouteHTTP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + }) if err := s.updateMapping(ctx, mapping); err != nil { - s.removeMapping(ctx, mapping) return fmt.Errorf("update mapping for domain %q: %w", d, err) } if wildcardHit { - if err := s.NotifyCertificateIssued(ctx, string(accountID), serviceID, string(d)); err != nil { + if err := s.NotifyCertificateIssued(ctx, accountID, svcID, string(d)); err != nil { s.Logger.Warnf("notify certificate ready for domain %q: %v", d, err) } } @@ -719,56 +1049,386 @@ func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) er return nil } +// setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port. +func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + port, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("TCP service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for TCP service %s", svcID) + } + + if s.WireguardPort != 0 && port == s.WireguardPort { + return fmt.Errorf("port %d conflicts with tunnel port", port) + } + + router, err := s.routerForPort(ctx, port) + if err != nil { + return fmt.Errorf("router for TCP port %d: %w", port, err) + } + + router.SetFallback(nbtcp.Route{ + Type: nbtcp.RouteTCP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + Protocol: accesslog.ProtocolTCP, + Target: targetAddr, + ProxyProtocol: s.l4ProxyProtocol(mapping), + DialTimeout: s.l4DialTimeout(mapping), + }) + + s.portMu.Lock() + s.svcPorts[svcID] = []uint16{port} + s.portMu.Unlock() + + s.meter.L4ServiceAdded(types.ServiceModeTCP) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// setupUDPMapping starts a UDP relay on the listen port. +func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + port, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("UDP service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for UDP service %s", svcID) + } + + if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil { + return fmt.Errorf("UDP relay for service %s: %w", svcID, err) + } + + s.meter.L4ServiceAdded(types.ServiceModeUDP) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// setupTLSMapping configures a TLS SNI-routed passthrough on the listen port. +func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + tlsPort, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("TLS service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for TLS service %s", svcID) + } + + if s.WireguardPort != 0 && tlsPort == s.WireguardPort { + return fmt.Errorf("port %d conflicts with tunnel port", tlsPort) + } + + router, err := s.routerForPort(ctx, tlsPort) + if err != nil { + return fmt.Errorf("router for TLS port %d: %w", tlsPort, err) + } + + router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ + Type: nbtcp.RouteTCP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + Protocol: accesslog.ProtocolTLS, + Target: targetAddr, + ProxyProtocol: s.l4ProxyProtocol(mapping), + DialTimeout: s.l4DialTimeout(mapping), + }) + + if tlsPort != s.mainPort { + s.portMu.Lock() + s.svcPorts[svcID] = []uint16{tlsPort} + s.portMu.Unlock() + } + + s.Logger.WithFields(log.Fields{ + "domain": mapping.GetDomain(), + "target": targetAddr, + "port": tlsPort, + "service": svcID, + }).Info("TLS passthrough mapping added") + + s.meter.L4ServiceAdded(types.ServiceModeTLS) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// serviceKeyForMapping returns the appropriate ServiceKey for a mapping. +// TCP/UDP use an ID-based key; HTTP/TLS use a domain-based key. +func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.ServiceKey { + switch types.ServiceMode(mapping.GetMode()) { + case types.ServiceModeTCP, types.ServiceModeUDP: + return roundtrip.L4ServiceKey(types.ServiceID(mapping.GetId())) + default: + return roundtrip.DomainServiceKey(mapping.GetDomain()) + } +} + +// l4TargetAddress extracts and validates the target address from a mapping's +// first path entry. Returns empty string if no paths exist or the address is +// not a valid host:port. +func (s *Server) l4TargetAddress(mapping *proto.ProxyMapping) string { + paths := mapping.GetPath() + if len(paths) == 0 { + return "" + } + target := paths[0].GetTarget() + if _, _, err := net.SplitHostPort(target); err != nil { + s.Logger.WithFields(log.Fields{ + "service_id": mapping.GetId(), + "target": target, + }).Warnf("invalid L4 target address: %v", err) + return "" + } + return target +} + +// l4ProxyProtocol returns whether the first target has PROXY protocol enabled. +func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool { + paths := mapping.GetPath() + if len(paths) == 0 { + return false + } + return paths[0].GetOptions().GetProxyProtocol() +} + +// l4DialTimeout returns the dial timeout from the first target's options, +// falling back to the server's DefaultDialTimeout. +func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration { + paths := mapping.GetPath() + if len(paths) > 0 { + if d := paths[0].GetOptions().GetRequestTimeout(); d != nil { + return d.AsDuration() + } + } + return s.DefaultDialTimeout +} + +// l4SessionIdleTimeout returns the configured session idle timeout from the +// mapping options, or 0 to use the relay's default. +func l4SessionIdleTimeout(mapping *proto.ProxyMapping) time.Duration { + paths := mapping.GetPath() + if len(paths) > 0 { + if d := paths[0].GetOptions().GetSessionIdleTimeout(); d != nil { + return d.AsDuration() + } + } + return 0 +} + +// addUDPRelay starts a UDP relay on the specified listen port. +func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, targetAddress string, listenPort uint16) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + if s.WireguardPort != 0 && listenPort == s.WireguardPort { + return fmt.Errorf("UDP port %d conflicts with tunnel port", listenPort) + } + + // Close existing relay if present (idempotent re-add). + s.removeUDPRelay(svcID) + + listenAddr := fmt.Sprintf(":%d", listenPort) + + listener, err := net.ListenPacket("udp", listenAddr) + if err != nil { + return fmt.Errorf("listen UDP on %s: %w", listenAddr, err) + } + + dialFn, err := s.resolveDialFunc(accountID) + if err != nil { + _ = listener.Close() + return fmt.Errorf("resolve dialer for UDP: %w", err) + } + + entry := s.Logger.WithFields(log.Fields{ + "target": targetAddress, + "listen_port": listenPort, + "service_id": svcID, + }) + + relay := udprelay.New(ctx, udprelay.RelayConfig{ + Logger: entry, + Listener: listener, + Target: targetAddress, + Domain: mapping.GetDomain(), + AccountID: accountID, + ServiceID: svcID, + DialFunc: dialFn, + DialTimeout: s.l4DialTimeout(mapping), + SessionTTL: l4SessionIdleTimeout(mapping), + AccessLog: s.accessLog, + }) + relay.SetObserver(s.meter) + + s.udpMu.Lock() + s.udpRelays[svcID] = relay + s.udpMu.Unlock() + + s.udpRelayWg.Go(relay.Serve) + entry.Info("UDP relay added") + return nil +} + func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) error { // Very simple implementation here, we don't touch the existing peer // connection or any existing TLS configuration, we simply overwrite // the auth and proxy mappings. // Note: this does require the management server to always send a // full mapping rather than deltas during a modification. + accountID := types.AccountID(mapping.GetAccountId()) + svcID := types.ServiceID(mapping.GetId()) + var schemes []auth.Scheme if mapping.GetAuth().GetPassword() { - schemes = append(schemes, auth.NewPassword(s.mgmtClient, mapping.GetId(), mapping.GetAccountId())) + schemes = append(schemes, auth.NewPassword(s.mgmtClient, svcID, accountID)) } if mapping.GetAuth().GetPin() { - schemes = append(schemes, auth.NewPin(s.mgmtClient, mapping.GetId(), mapping.GetAccountId())) + schemes = append(schemes, auth.NewPin(s.mgmtClient, svcID, accountID)) } if mapping.GetAuth().GetOidc() { - schemes = append(schemes, auth.NewOIDC(s.mgmtClient, mapping.GetId(), mapping.GetAccountId(), s.ForwardedProto)) + schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto)) } maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second - if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, mapping.GetAccountId(), mapping.GetId()); err != nil { + if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID); err != nil { return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err) } - s.proxy.AddMapping(s.protoToMapping(mapping)) - s.meter.AddMapping(s.protoToMapping(mapping)) + m := s.protoToMapping(ctx, mapping) + s.proxy.AddMapping(m) + s.meter.AddMapping(m) return nil } +// removeMapping tears down routes/relays and the NetBird peer for a service. +// Uses the stored mapping state when available to ensure all previously +// configured routes are cleaned up. func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) { - d := domain.Domain(mapping.GetDomain()) accountID := types.AccountID(mapping.GetAccountId()) - if err := s.netbird.RemovePeer(ctx, accountID, d); err != nil { + svcKey := s.serviceKeyForMapping(mapping) + if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil { s.Logger.WithFields(log.Fields{ "account_id": accountID, - "domain": d, + "service_id": mapping.GetId(), "error": err, - }).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist") + }).Error("failed to remove NetBird peer, continuing cleanup") } - if s.acme != nil { - s.acme.RemoveDomain(d) + + if old := s.deleteMapping(types.ServiceID(mapping.GetId())); old != nil { + s.cleanupMappingRoutes(old) + if mode := types.ServiceMode(old.GetMode()); mode.IsL4() { + s.meter.L4ServiceRemoved(mode) + } + } else { + s.cleanupMappingRoutes(mapping) } - s.auth.RemoveDomain(mapping.GetDomain()) - s.proxy.RemoveMapping(s.protoToMapping(mapping)) - s.meter.RemoveMapping(s.protoToMapping(mapping)) } -func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { +// cleanupMappingRoutes removes HTTP/TLS/L4 routes and custom port state for a +// service without touching the NetBird peer. This is used for both full +// removal and in-place modification of mappings. +func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) { + svcID := types.ServiceID(mapping.GetId()) + host := mapping.GetDomain() + + // HTTP/TLS cleanup (only relevant when a domain is set). + if host != "" { + d := domain.Domain(host) + if s.acme != nil { + s.acme.RemoveDomain(d) + } + s.auth.RemoveDomain(host) + if s.proxy.RemoveMapping(proxy.Mapping{Host: host}) { + s.meter.RemoveMapping(proxy.Mapping{Host: host}) + } + // Close hijacked connections (WebSocket) for this domain. + if n := s.hijackTracker.CloseByHost(host); n > 0 { + s.Logger.Debugf("closed %d hijacked connection(s) for %s", n, host) + } + // Remove SNI route from the main router (covers both HTTP and main-port TLS). + s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID) + } + + // Extract and delete tracked custom-port entries atomically. + s.portMu.Lock() + entries := s.svcPorts[svcID] + delete(s.svcPorts, svcID) + s.portMu.Unlock() + + for _, entry := range entries { + if router := s.routerForPortExisting(entry); router != nil { + if host != "" { + router.RemoveRoute(nbtcp.SNIHost(host), svcID) + } else { + router.RemoveFallback(svcID) + } + } + s.cleanupPortIfEmpty(entry) + } + + // UDP relay cleanup (idempotent). + s.removeUDPRelay(svcID) + +} + +// removeUDPRelay stops and removes a UDP relay by service ID. +func (s *Server) removeUDPRelay(svcID types.ServiceID) { + s.udpMu.Lock() + relay, ok := s.udpRelays[svcID] + if ok { + delete(s.udpRelays, svcID) + } + s.udpMu.Unlock() + + if ok { + relay.Close() + s.Logger.WithField("service_id", svcID).Info("UDP relay removed") + } +} + +func (s *Server) storeMapping(mapping *proto.ProxyMapping) { + s.portMu.Lock() + s.lastMappings[types.ServiceID(mapping.GetId())] = mapping + s.portMu.Unlock() +} + +func (s *Server) loadMapping(svcID types.ServiceID) *proto.ProxyMapping { + s.portMu.RLock() + m := s.lastMappings[svcID] + s.portMu.RUnlock() + return m +} + +func (s *Server) deleteMapping(svcID types.ServiceID) *proto.ProxyMapping { + s.portMu.Lock() + m := s.lastMappings[svcID] + delete(s.lastMappings, svcID) + s.portMu.Unlock() + return m +} + +func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping) proxy.Mapping { paths := make(map[string]*proxy.PathTarget) for _, pathMapping := range mapping.GetPath() { targetURL, err := url.Parse(pathMapping.GetTarget()) if err != nil { - // TODO: Should we warn management about this so it can be bubbled up to a user to reconfigure? s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "account_id": mapping.GetAccountId(), @@ -776,6 +1436,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { "path": pathMapping.GetPath(), "target": pathMapping.GetTarget(), }).WithError(err).Error("failed to parse target URL for path, skipping") + s.notifyError(ctx, mapping, fmt.Errorf("invalid target URL %q for path %q: %w", pathMapping.GetTarget(), pathMapping.GetPath(), err)) continue } @@ -788,10 +1449,13 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { pt.RequestTimeout = d.AsDuration() } } + if pt.RequestTimeout == 0 && s.DefaultDialTimeout > 0 { + pt.RequestTimeout = s.DefaultDialTimeout + } paths[pathMapping.GetPath()] = pt } return proxy.Mapping{ - ID: mapping.GetId(), + ID: types.ServiceID(mapping.GetId()), AccountID: types.AccountID(mapping.GetAccountId()), Host: mapping.GetDomain(), Paths: paths, diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 9505b3fdf..333f0bf00 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -56,12 +56,14 @@ type ExposeRequest struct { Pin string Password string UserGroups []string + ListenPort uint16 } type ExposeResponse struct { - ServiceName string - Domain string - ServiceURL string + ServiceName string + Domain string + ServiceURL string + PortAutoAssigned bool } // NewClient creates a new client to Management service @@ -790,9 +792,10 @@ func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error { func fromProtoExposeResponse(resp *proto.ExposeServiceResponse) *ExposeResponse { return &ExposeResponse{ - ServiceName: resp.ServiceName, - Domain: resp.Domain, - ServiceURL: resp.ServiceUrl, + ServiceName: resp.ServiceName, + Domain: resp.Domain, + ServiceURL: resp.ServiceUrl, + PortAutoAssigned: resp.PortAutoAssigned, } } @@ -808,6 +811,8 @@ func toProtoExposeServiceRequest(req ExposeRequest) (*proto.ExposeServiceRequest protocol = proto.ExposeProtocol_EXPOSE_TCP case int(proto.ExposeProtocol_EXPOSE_UDP): protocol = proto.ExposeProtocol_EXPOSE_UDP + case int(proto.ExposeProtocol_EXPOSE_TLS): + protocol = proto.ExposeProtocol_EXPOSE_TLS default: return nil, fmt.Errorf("invalid expose protocol: %d", req.Protocol) } @@ -820,6 +825,7 @@ func toProtoExposeServiceRequest(req ExposeRequest) (*proto.ExposeServiceRequest Pin: req.Pin, Password: req.Password, UserGroups: req.UserGroups, + ListenPort: uint32(req.ListenPort), }, nil } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 6d2967aa9..4b851bf19 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -2836,6 +2836,10 @@ components: format: int64 description: "Bytes downloaded (response body size)" example: 8192 + protocol: + type: string + description: "Protocol type: http, tcp, or udp" + example: "http" required: - id - service_id @@ -2954,6 +2958,20 @@ components: domain: type: string description: Domain for the service + mode: + type: string + description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + enum: [http, tcp, udp, tls] + default: http + listen_port: + type: integer + minimum: 0 + maximum: 65535 + description: Port the proxy listens on (L4/TLS only) + port_auto_assigned: + type: boolean + description: Whether the listen port was auto-assigned + readOnly: true proxy_cluster: type: string description: The proxy cluster handling this service (derived from domain) @@ -3020,6 +3038,16 @@ components: domain: type: string description: Domain for the service + mode: + type: string + description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + enum: [http, tcp, udp, tls] + default: http + listen_port: + type: integer + minimum: 0 + maximum: 65535 + description: Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment. targets: type: array items: @@ -3040,8 +3068,6 @@ components: required: - name - domain - - targets - - auth - enabled ServiceTargetOptions: type: object @@ -3065,6 +3091,12 @@ components: additionalProperties: type: string pattern: '^[^\r\n]*$' + proxy_protocol: + type: boolean + description: Send PROXY Protocol v2 header to this backend (TCP/TLS only) + session_idle_timeout: + type: string + description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m. ServiceTarget: type: object properties: @@ -3073,21 +3105,23 @@ components: description: Target ID target_type: type: string - description: Target type (e.g., "peer", "resource") - enum: [peer, resource] + description: Target type + enum: [peer, host, domain, subnet] path: type: string - description: URL path prefix for this target + description: URL path prefix for this target (HTTP only) protocol: type: string description: Protocol to use when connecting to the backend - enum: [http, https] + enum: [http, https, tcp, udp] host: type: string description: Backend ip or domain for this target port: type: integer - description: Backend port for this target. Use 0 or omit to use the scheme default (80 for http, 443 for https). + minimum: 1 + maximum: 65535 + description: Backend port for this target enabled: type: boolean description: Whether this target is enabled @@ -3194,6 +3228,9 @@ components: target_cluster: type: string description: The proxy cluster this domain is validated against (only for custom domains) + supports_custom_ports: + type: boolean + description: Whether the cluster supports binding arbitrary TCP/UDP ports required: - id - domain @@ -4277,6 +4314,12 @@ components: requires_authentication: description: Requires authentication content: { } + conflict: + description: Conflict + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' securitySchemes: BearerAuth: type: http @@ -9621,6 +9664,29 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorResponse' + /api/reverse-proxies/clusters: + get: + summary: List available proxy clusters + description: Returns a list of available proxy clusters with their connection status + tags: [ Services ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of proxy clusters + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProxyCluster' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services: get: summary: List all Services @@ -9670,29 +9736,8 @@ paths: "$ref": "#/components/responses/requires_authentication" '403': "$ref": "#/components/responses/forbidden" - '500': - "$ref": "#/components/responses/internal_error" - /api/reverse-proxies/clusters: - get: - summary: List available proxy clusters - description: Returns a list of available proxy clusters with their connection status - tags: [ Services ] - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - responses: - '200': - description: A JSON Array of proxy clusters - content: - application/json: - schema: - type: array - items: - $ref: '#/components/schemas/ProxyCluster' - '401': - "$ref": "#/components/responses/requires_authentication" - '403': - "$ref": "#/components/responses/forbidden" + '409': + "$ref": "#/components/responses/conflict" '500': "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services/{serviceId}: @@ -9762,6 +9807,8 @@ paths: "$ref": "#/components/responses/forbidden" '404': "$ref": "#/components/responses/not_found" + '409': + "$ref": "#/components/responses/conflict" '500': "$ref": "#/components/responses/internal_error" delete: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index f5a2b7ced..4ec3b871a 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -880,6 +880,30 @@ func (e SentinelOneMatchAttributesNetworkStatus) Valid() bool { } } +// Defines values for ServiceMode. +const ( + ServiceModeHttp ServiceMode = "http" + ServiceModeTcp ServiceMode = "tcp" + ServiceModeTls ServiceMode = "tls" + ServiceModeUdp ServiceMode = "udp" +) + +// Valid indicates whether the value is a known member of the ServiceMode enum. +func (e ServiceMode) Valid() bool { + switch e { + case ServiceModeHttp: + return true + case ServiceModeTcp: + return true + case ServiceModeTls: + return true + case ServiceModeUdp: + return true + default: + return false + } +} + // Defines values for ServiceMetaStatus. const ( ServiceMetaStatusActive ServiceMetaStatus = "active" @@ -910,10 +934,36 @@ func (e ServiceMetaStatus) Valid() bool { } } +// Defines values for ServiceRequestMode. +const ( + ServiceRequestModeHttp ServiceRequestMode = "http" + ServiceRequestModeTcp ServiceRequestMode = "tcp" + ServiceRequestModeTls ServiceRequestMode = "tls" + ServiceRequestModeUdp ServiceRequestMode = "udp" +) + +// Valid indicates whether the value is a known member of the ServiceRequestMode enum. +func (e ServiceRequestMode) Valid() bool { + switch e { + case ServiceRequestModeHttp: + return true + case ServiceRequestModeTcp: + return true + case ServiceRequestModeTls: + return true + case ServiceRequestModeUdp: + return true + default: + return false + } +} + // Defines values for ServiceTargetProtocol. const ( ServiceTargetProtocolHttp ServiceTargetProtocol = "http" ServiceTargetProtocolHttps ServiceTargetProtocol = "https" + ServiceTargetProtocolTcp ServiceTargetProtocol = "tcp" + ServiceTargetProtocolUdp ServiceTargetProtocol = "udp" ) // Valid indicates whether the value is a known member of the ServiceTargetProtocol enum. @@ -923,6 +973,10 @@ func (e ServiceTargetProtocol) Valid() bool { return true case ServiceTargetProtocolHttps: return true + case ServiceTargetProtocolTcp: + return true + case ServiceTargetProtocolUdp: + return true default: return false } @@ -930,16 +984,22 @@ func (e ServiceTargetProtocol) Valid() bool { // Defines values for ServiceTargetTargetType. const ( - ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" - ServiceTargetTargetTypeResource ServiceTargetTargetType = "resource" + ServiceTargetTargetTypeDomain ServiceTargetTargetType = "domain" + ServiceTargetTargetTypeHost ServiceTargetTargetType = "host" + ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" + ServiceTargetTargetTypeSubnet ServiceTargetTargetType = "subnet" ) // Valid indicates whether the value is a known member of the ServiceTargetTargetType enum. func (e ServiceTargetTargetType) Valid() bool { switch e { + case ServiceTargetTargetTypeDomain: + return true + case ServiceTargetTargetTypeHost: + return true case ServiceTargetTargetTypePeer: return true - case ServiceTargetTargetTypeResource: + case ServiceTargetTargetTypeSubnet: return true default: return false @@ -3249,6 +3309,9 @@ type ProxyAccessLog struct { // Path Path of the request Path string `json:"path"` + // Protocol Protocol type: http, tcp, or udp + Protocol *string `json:"protocol,omitempty"` + // Reason Reason for the request result (e.g., authentication failure) Reason *string `json:"reason,omitempty"` @@ -3313,6 +3376,9 @@ type ReverseProxyDomain struct { // Id Domain ID Id string `json:"id"` + // SupportsCustomPorts Whether the cluster supports binding arbitrary TCP/UDP ports + SupportsCustomPorts *bool `json:"supports_custom_ports,omitempty"` + // TargetCluster The proxy cluster this domain is validated against (only for custom domains) TargetCluster *string `json:"target_cluster,omitempty"` @@ -3505,8 +3571,14 @@ type Service struct { Enabled bool `json:"enabled"` // Id Service ID - Id string `json:"id"` - Meta ServiceMeta `json:"meta"` + Id string `json:"id"` + + // ListenPort Port the proxy listens on (L4/TLS only) + ListenPort *int `json:"listen_port,omitempty"` + Meta ServiceMeta `json:"meta"` + + // Mode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + Mode *ServiceMode `json:"mode,omitempty"` // Name Service name Name string `json:"name"` @@ -3514,6 +3586,9 @@ type Service struct { // PassHostHeader When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address PassHostHeader *bool `json:"pass_host_header,omitempty"` + // PortAutoAssigned Whether the listen port was auto-assigned + PortAutoAssigned *bool `json:"port_auto_assigned,omitempty"` + // ProxyCluster The proxy cluster handling this service (derived from domain) ProxyCluster *string `json:"proxy_cluster,omitempty"` @@ -3524,6 +3599,9 @@ type Service struct { Targets []ServiceTarget `json:"targets"` } +// ServiceMode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. +type ServiceMode string + // ServiceAuthConfig defines model for ServiceAuthConfig. type ServiceAuthConfig struct { BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty"` @@ -3549,7 +3627,7 @@ type ServiceMetaStatus string // ServiceRequest defines model for ServiceRequest. type ServiceRequest struct { - Auth ServiceAuthConfig `json:"auth"` + Auth *ServiceAuthConfig `json:"auth,omitempty"` // Domain Domain for the service Domain string `json:"domain"` @@ -3557,6 +3635,12 @@ type ServiceRequest struct { // Enabled Whether the service is enabled Enabled bool `json:"enabled"` + // ListenPort Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment. + ListenPort *int `json:"listen_port,omitempty"` + + // Mode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + Mode *ServiceRequestMode `json:"mode,omitempty"` + // Name Service name Name string `json:"name"` @@ -3567,9 +3651,12 @@ type ServiceRequest struct { RewriteRedirects *bool `json:"rewrite_redirects,omitempty"` // Targets List of target backends for this service - Targets []ServiceTarget `json:"targets"` + Targets *[]ServiceTarget `json:"targets,omitempty"` } +// ServiceRequestMode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. +type ServiceRequestMode string + // ServiceTarget defines model for ServiceTarget. type ServiceTarget struct { // Enabled Whether this target is enabled @@ -3579,10 +3666,10 @@ type ServiceTarget struct { Host *string `json:"host,omitempty"` Options *ServiceTargetOptions `json:"options,omitempty"` - // Path URL path prefix for this target + // Path URL path prefix for this target (HTTP only) Path *string `json:"path,omitempty"` - // Port Backend port for this target. Use 0 or omit to use the scheme default (80 for http, 443 for https). + // Port Backend port for this target Port int `json:"port"` // Protocol Protocol to use when connecting to the backend @@ -3591,14 +3678,14 @@ type ServiceTarget struct { // TargetId Target ID TargetId string `json:"target_id"` - // TargetType Target type (e.g., "peer", "resource") + // TargetType Target type TargetType ServiceTargetTargetType `json:"target_type"` } // ServiceTargetProtocol Protocol to use when connecting to the backend type ServiceTargetProtocol string -// ServiceTargetTargetType Target type (e.g., "peer", "resource") +// ServiceTargetTargetType Target type type ServiceTargetTargetType string // ServiceTargetOptions defines model for ServiceTargetOptions. @@ -3609,9 +3696,15 @@ type ServiceTargetOptions struct { // PathRewrite Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path. PathRewrite *ServiceTargetOptionsPathRewrite `json:"path_rewrite,omitempty"` + // ProxyProtocol Send PROXY Protocol v2 header to this backend (TCP/TLS only) + ProxyProtocol *bool `json:"proxy_protocol,omitempty"` + // RequestTimeout Per-target response timeout as a Go duration string (e.g. "30s", "2m") RequestTimeout *string `json:"request_timeout,omitempty"` + // SessionIdleTimeout Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m. + SessionIdleTimeout *string `json:"session_idle_timeout,omitempty"` + // SkipTlsVerify Skip TLS certificate verification for this backend SkipTlsVerify *bool `json:"skip_tls_verify,omitempty"` } @@ -4136,6 +4229,9 @@ type ZoneRequest struct { Name string `json:"name"` } +// Conflict Standard error response. Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided. +type Conflict = ErrorResponse + // GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. type GetApiEventsNetworkTrafficParams struct { // Page Page number diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 2c66bb946..c5581296c 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -228,6 +228,7 @@ const ( ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1 ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2 ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3 + ExposeProtocol_EXPOSE_TLS ExposeProtocol = 4 ) // Enum value maps for ExposeProtocol. @@ -237,12 +238,14 @@ var ( 1: "EXPOSE_HTTPS", 2: "EXPOSE_TCP", 3: "EXPOSE_UDP", + 4: "EXPOSE_TLS", } ExposeProtocol_value = map[string]int32{ "EXPOSE_HTTP": 0, "EXPOSE_HTTPS": 1, "EXPOSE_TCP": 2, "EXPOSE_UDP": 3, + "EXPOSE_TLS": 4, } ) @@ -4047,6 +4050,7 @@ type ExposeServiceRequest struct { UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"` Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"` NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"` + ListenPort uint32 `protobuf:"varint,8,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` } func (x *ExposeServiceRequest) Reset() { @@ -4130,14 +4134,22 @@ func (x *ExposeServiceRequest) GetNamePrefix() string { return "" } +func (x *ExposeServiceRequest) GetListenPort() uint32 { + if x != nil { + return x.ListenPort + } + return 0 +} + type ExposeServiceResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` - ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` - Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` + ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` + Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + PortAutoAssigned bool `protobuf:"varint,4,opt,name=port_auto_assigned,json=portAutoAssigned,proto3" json:"port_auto_assigned,omitempty"` } func (x *ExposeServiceResponse) Reset() { @@ -4193,6 +4205,13 @@ func (x *ExposeServiceResponse) GetDomain() string { return "" } +func (x *ExposeServiceResponse) GetPortAutoAssigned() bool { + if x != nil { + return x.PortAutoAssigned + } + return false +} + type RenewExposeRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -4996,7 +5015,7 @@ var file_management_proto_rawDesc = []byte{ 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, - 0x74, 0x22, 0xea, 0x01, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x74, 0x22, 0x8b, 0x02, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x36, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, @@ -5010,15 +5029,20 @@ var file_management_proto_rawDesc = []byte{ 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x22, 0x73, - 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, + 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x1f, + 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x22, + 0xa1, 0x01, 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, 0x6c, 0x12, 0x16, 0x0a, + 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2c, 0x0a, 0x12, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x61, 0x75, + 0x74, 0x6f, 0x5f, 0x61, 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x10, 0x70, 0x6f, 0x72, 0x74, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x73, 0x73, 0x69, 0x67, + 0x6e, 0x65, 0x64, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x15, 0x0a, 0x13, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, @@ -5039,12 +5063,13 @@ var file_management_proto_rawDesc = []byte{ 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, - 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, 0x53, 0x0a, 0x0e, 0x45, + 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, 0x63, 0x0a, 0x0e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0f, 0x0a, 0x0b, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x55, 0x44, 0x50, 0x10, 0x03, + 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x32, 0xfd, 0x06, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index fdbe3a365..9acf7e2b3 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -652,6 +652,7 @@ enum ExposeProtocol { EXPOSE_HTTPS = 1; EXPOSE_TCP = 2; EXPOSE_UDP = 3; + EXPOSE_TLS = 4; } message ExposeServiceRequest { @@ -662,12 +663,14 @@ message ExposeServiceRequest { repeated string user_groups = 5; string domain = 6; string name_prefix = 7; + uint32 listen_port = 8; } message ExposeServiceResponse { string service_name = 1; string service_url = 2; string domain = 3; + bool port_auto_assigned = 4; } message RenewExposeRequest { diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index 275e8be37..115ac5101 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.0 +// protoc v6.33.3 // source: proxy_service.proto package proto @@ -175,22 +175,72 @@ func (ProxyStatus) EnumDescriptor() ([]byte, []int) { return file_proxy_service_proto_rawDescGZIP(), []int{2} } +// ProxyCapabilities describes what a proxy can handle. +type ProxyCapabilities struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. + SupportsCustomPorts *bool `protobuf:"varint,1,opt,name=supports_custom_ports,json=supportsCustomPorts,proto3,oneof" json:"supports_custom_ports,omitempty"` +} + +func (x *ProxyCapabilities) Reset() { + *x = ProxyCapabilities{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ProxyCapabilities) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ProxyCapabilities) ProtoMessage() {} + +func (x *ProxyCapabilities) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ProxyCapabilities.ProtoReflect.Descriptor instead. +func (*ProxyCapabilities) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{0} +} + +func (x *ProxyCapabilities) GetSupportsCustomPorts() bool { + if x != nil && x.SupportsCustomPorts != nil { + return *x.SupportsCustomPorts + } + return false +} + // GetMappingUpdateRequest is sent to initialise a mapping stream. type GetMappingUpdateRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` - Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` - StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` - Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` + Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` + StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` + Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + Capabilities *ProxyCapabilities `protobuf:"bytes,5,opt,name=capabilities,proto3" json:"capabilities,omitempty"` } func (x *GetMappingUpdateRequest) Reset() { *x = GetMappingUpdateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[0] + mi := &file_proxy_service_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -203,7 +253,7 @@ func (x *GetMappingUpdateRequest) String() string { func (*GetMappingUpdateRequest) ProtoMessage() {} func (x *GetMappingUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[0] + mi := &file_proxy_service_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -216,7 +266,7 @@ func (x *GetMappingUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMappingUpdateRequest.ProtoReflect.Descriptor instead. func (*GetMappingUpdateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{0} + return file_proxy_service_proto_rawDescGZIP(), []int{1} } func (x *GetMappingUpdateRequest) GetProxyId() string { @@ -247,6 +297,13 @@ func (x *GetMappingUpdateRequest) GetAddress() string { return "" } +func (x *GetMappingUpdateRequest) GetCapabilities() *ProxyCapabilities { + if x != nil { + return x.Capabilities + } + return nil +} + // GetMappingUpdateResponse contains zero or more ProxyMappings. // No mappings may be sent to test the liveness of the Proxy. // Mappings that are sent should be interpreted by the Proxy appropriately. @@ -264,7 +321,7 @@ type GetMappingUpdateResponse struct { func (x *GetMappingUpdateResponse) Reset() { *x = GetMappingUpdateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[1] + mi := &file_proxy_service_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -277,7 +334,7 @@ func (x *GetMappingUpdateResponse) String() string { func (*GetMappingUpdateResponse) ProtoMessage() {} func (x *GetMappingUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[1] + mi := &file_proxy_service_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -290,7 +347,7 @@ func (x *GetMappingUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMappingUpdateResponse.ProtoReflect.Descriptor instead. func (*GetMappingUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{1} + return file_proxy_service_proto_rawDescGZIP(), []int{2} } func (x *GetMappingUpdateResponse) GetMapping() []*ProxyMapping { @@ -316,12 +373,16 @@ type PathTargetOptions struct { RequestTimeout *durationpb.Duration `protobuf:"bytes,2,opt,name=request_timeout,json=requestTimeout,proto3" json:"request_timeout,omitempty"` PathRewrite PathRewriteMode `protobuf:"varint,3,opt,name=path_rewrite,json=pathRewrite,proto3,enum=management.PathRewriteMode" json:"path_rewrite,omitempty"` CustomHeaders map[string]string `protobuf:"bytes,4,rep,name=custom_headers,json=customHeaders,proto3" json:"custom_headers,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + // Send PROXY protocol v2 header to this backend. + ProxyProtocol bool `protobuf:"varint,5,opt,name=proxy_protocol,json=proxyProtocol,proto3" json:"proxy_protocol,omitempty"` + // Idle timeout before a UDP session is reaped. + SessionIdleTimeout *durationpb.Duration `protobuf:"bytes,6,opt,name=session_idle_timeout,json=sessionIdleTimeout,proto3" json:"session_idle_timeout,omitempty"` } func (x *PathTargetOptions) Reset() { *x = PathTargetOptions{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[2] + mi := &file_proxy_service_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -334,7 +395,7 @@ func (x *PathTargetOptions) String() string { func (*PathTargetOptions) ProtoMessage() {} func (x *PathTargetOptions) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[2] + mi := &file_proxy_service_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -347,7 +408,7 @@ func (x *PathTargetOptions) ProtoReflect() protoreflect.Message { // Deprecated: Use PathTargetOptions.ProtoReflect.Descriptor instead. func (*PathTargetOptions) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{2} + return file_proxy_service_proto_rawDescGZIP(), []int{3} } func (x *PathTargetOptions) GetSkipTlsVerify() bool { @@ -378,6 +439,20 @@ func (x *PathTargetOptions) GetCustomHeaders() map[string]string { return nil } +func (x *PathTargetOptions) GetProxyProtocol() bool { + if x != nil { + return x.ProxyProtocol + } + return false +} + +func (x *PathTargetOptions) GetSessionIdleTimeout() *durationpb.Duration { + if x != nil { + return x.SessionIdleTimeout + } + return nil +} + type PathMapping struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -391,7 +466,7 @@ type PathMapping struct { func (x *PathMapping) Reset() { *x = PathMapping{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[3] + mi := &file_proxy_service_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -404,7 +479,7 @@ func (x *PathMapping) String() string { func (*PathMapping) ProtoMessage() {} func (x *PathMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[3] + mi := &file_proxy_service_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -417,7 +492,7 @@ func (x *PathMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use PathMapping.ProtoReflect.Descriptor instead. func (*PathMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{3} + return file_proxy_service_proto_rawDescGZIP(), []int{4} } func (x *PathMapping) GetPath() string { @@ -456,7 +531,7 @@ type Authentication struct { func (x *Authentication) Reset() { *x = Authentication{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[4] + mi := &file_proxy_service_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -469,7 +544,7 @@ func (x *Authentication) String() string { func (*Authentication) ProtoMessage() {} func (x *Authentication) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[4] + mi := &file_proxy_service_proto_msgTypes[5] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -482,7 +557,7 @@ func (x *Authentication) ProtoReflect() protoreflect.Message { // Deprecated: Use Authentication.ProtoReflect.Descriptor instead. func (*Authentication) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{4} + return file_proxy_service_proto_rawDescGZIP(), []int{5} } func (x *Authentication) GetSessionKey() string { @@ -538,12 +613,16 @@ type ProxyMapping struct { // When true, Location headers in backend responses are rewritten to replace // the backend address with the public-facing domain. RewriteRedirects bool `protobuf:"varint,9,opt,name=rewrite_redirects,json=rewriteRedirects,proto3" json:"rewrite_redirects,omitempty"` + // Service mode: "http", "tcp", "udp", or "tls". + Mode string `protobuf:"bytes,10,opt,name=mode,proto3" json:"mode,omitempty"` + // For L4/TLS: the port the proxy listens on. + ListenPort int32 `protobuf:"varint,11,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` } func (x *ProxyMapping) Reset() { *x = ProxyMapping{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[5] + mi := &file_proxy_service_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -556,7 +635,7 @@ func (x *ProxyMapping) String() string { func (*ProxyMapping) ProtoMessage() {} func (x *ProxyMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[5] + mi := &file_proxy_service_proto_msgTypes[6] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -569,7 +648,7 @@ func (x *ProxyMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use ProxyMapping.ProtoReflect.Descriptor instead. func (*ProxyMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{5} + return file_proxy_service_proto_rawDescGZIP(), []int{6} } func (x *ProxyMapping) GetType() ProxyMappingUpdateType { @@ -635,6 +714,20 @@ func (x *ProxyMapping) GetRewriteRedirects() bool { return false } +func (x *ProxyMapping) GetMode() string { + if x != nil { + return x.Mode + } + return "" +} + +func (x *ProxyMapping) GetListenPort() int32 { + if x != nil { + return x.ListenPort + } + return 0 +} + // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. type SendAccessLogRequest struct { state protoimpl.MessageState @@ -647,7 +740,7 @@ type SendAccessLogRequest struct { func (x *SendAccessLogRequest) Reset() { *x = SendAccessLogRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[6] + mi := &file_proxy_service_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -660,7 +753,7 @@ func (x *SendAccessLogRequest) String() string { func (*SendAccessLogRequest) ProtoMessage() {} func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[6] + mi := &file_proxy_service_proto_msgTypes[7] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -673,7 +766,7 @@ func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogRequest.ProtoReflect.Descriptor instead. func (*SendAccessLogRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{6} + return file_proxy_service_proto_rawDescGZIP(), []int{7} } func (x *SendAccessLogRequest) GetLog() *AccessLog { @@ -693,7 +786,7 @@ type SendAccessLogResponse struct { func (x *SendAccessLogResponse) Reset() { *x = SendAccessLogResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[7] + mi := &file_proxy_service_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -706,7 +799,7 @@ func (x *SendAccessLogResponse) String() string { func (*SendAccessLogResponse) ProtoMessage() {} func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[7] + mi := &file_proxy_service_proto_msgTypes[8] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -719,7 +812,7 @@ func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogResponse.ProtoReflect.Descriptor instead. func (*SendAccessLogResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{7} + return file_proxy_service_proto_rawDescGZIP(), []int{8} } type AccessLog struct { @@ -742,12 +835,13 @@ type AccessLog struct { AuthSuccess bool `protobuf:"varint,13,opt,name=auth_success,json=authSuccess,proto3" json:"auth_success,omitempty"` BytesUpload int64 `protobuf:"varint,14,opt,name=bytes_upload,json=bytesUpload,proto3" json:"bytes_upload,omitempty"` BytesDownload int64 `protobuf:"varint,15,opt,name=bytes_download,json=bytesDownload,proto3" json:"bytes_download,omitempty"` + Protocol string `protobuf:"bytes,16,opt,name=protocol,proto3" json:"protocol,omitempty"` } func (x *AccessLog) Reset() { *x = AccessLog{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[8] + mi := &file_proxy_service_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -760,7 +854,7 @@ func (x *AccessLog) String() string { func (*AccessLog) ProtoMessage() {} func (x *AccessLog) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[8] + mi := &file_proxy_service_proto_msgTypes[9] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -773,7 +867,7 @@ func (x *AccessLog) ProtoReflect() protoreflect.Message { // Deprecated: Use AccessLog.ProtoReflect.Descriptor instead. func (*AccessLog) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{8} + return file_proxy_service_proto_rawDescGZIP(), []int{9} } func (x *AccessLog) GetTimestamp() *timestamppb.Timestamp { @@ -881,6 +975,13 @@ func (x *AccessLog) GetBytesDownload() int64 { return 0 } +func (x *AccessLog) GetProtocol() string { + if x != nil { + return x.Protocol + } + return "" +} + type AuthenticateRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -898,7 +999,7 @@ type AuthenticateRequest struct { func (x *AuthenticateRequest) Reset() { *x = AuthenticateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[9] + mi := &file_proxy_service_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -911,7 +1012,7 @@ func (x *AuthenticateRequest) String() string { func (*AuthenticateRequest) ProtoMessage() {} func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[9] + mi := &file_proxy_service_proto_msgTypes[10] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -924,7 +1025,7 @@ func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateRequest.ProtoReflect.Descriptor instead. func (*AuthenticateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{9} + return file_proxy_service_proto_rawDescGZIP(), []int{10} } func (x *AuthenticateRequest) GetId() string { @@ -989,7 +1090,7 @@ type PasswordRequest struct { func (x *PasswordRequest) Reset() { *x = PasswordRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[10] + mi := &file_proxy_service_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1002,7 +1103,7 @@ func (x *PasswordRequest) String() string { func (*PasswordRequest) ProtoMessage() {} func (x *PasswordRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[10] + mi := &file_proxy_service_proto_msgTypes[11] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1015,7 +1116,7 @@ func (x *PasswordRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PasswordRequest.ProtoReflect.Descriptor instead. func (*PasswordRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{10} + return file_proxy_service_proto_rawDescGZIP(), []int{11} } func (x *PasswordRequest) GetPassword() string { @@ -1036,7 +1137,7 @@ type PinRequest struct { func (x *PinRequest) Reset() { *x = PinRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[11] + mi := &file_proxy_service_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1049,7 +1150,7 @@ func (x *PinRequest) String() string { func (*PinRequest) ProtoMessage() {} func (x *PinRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[11] + mi := &file_proxy_service_proto_msgTypes[12] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1062,7 +1163,7 @@ func (x *PinRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PinRequest.ProtoReflect.Descriptor instead. func (*PinRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{11} + return file_proxy_service_proto_rawDescGZIP(), []int{12} } func (x *PinRequest) GetPin() string { @@ -1084,7 +1185,7 @@ type AuthenticateResponse struct { func (x *AuthenticateResponse) Reset() { *x = AuthenticateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[12] + mi := &file_proxy_service_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1097,7 +1198,7 @@ func (x *AuthenticateResponse) String() string { func (*AuthenticateResponse) ProtoMessage() {} func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[12] + mi := &file_proxy_service_proto_msgTypes[13] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1110,7 +1211,7 @@ func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateResponse.ProtoReflect.Descriptor instead. func (*AuthenticateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{12} + return file_proxy_service_proto_rawDescGZIP(), []int{13} } func (x *AuthenticateResponse) GetSuccess() bool { @@ -1143,7 +1244,7 @@ type SendStatusUpdateRequest struct { func (x *SendStatusUpdateRequest) Reset() { *x = SendStatusUpdateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[13] + mi := &file_proxy_service_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1156,7 +1257,7 @@ func (x *SendStatusUpdateRequest) String() string { func (*SendStatusUpdateRequest) ProtoMessage() {} func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[13] + mi := &file_proxy_service_proto_msgTypes[14] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1169,7 +1270,7 @@ func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateRequest.ProtoReflect.Descriptor instead. func (*SendStatusUpdateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{13} + return file_proxy_service_proto_rawDescGZIP(), []int{14} } func (x *SendStatusUpdateRequest) GetServiceId() string { @@ -1217,7 +1318,7 @@ type SendStatusUpdateResponse struct { func (x *SendStatusUpdateResponse) Reset() { *x = SendStatusUpdateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[14] + mi := &file_proxy_service_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1230,7 +1331,7 @@ func (x *SendStatusUpdateResponse) String() string { func (*SendStatusUpdateResponse) ProtoMessage() {} func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[14] + mi := &file_proxy_service_proto_msgTypes[15] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1243,7 +1344,7 @@ func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateResponse.ProtoReflect.Descriptor instead. func (*SendStatusUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{14} + return file_proxy_service_proto_rawDescGZIP(), []int{15} } // CreateProxyPeerRequest is sent by the proxy to create a peer connection @@ -1263,7 +1364,7 @@ type CreateProxyPeerRequest struct { func (x *CreateProxyPeerRequest) Reset() { *x = CreateProxyPeerRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[15] + mi := &file_proxy_service_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1276,7 +1377,7 @@ func (x *CreateProxyPeerRequest) String() string { func (*CreateProxyPeerRequest) ProtoMessage() {} func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[15] + mi := &file_proxy_service_proto_msgTypes[16] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1289,7 +1390,7 @@ func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerRequest.ProtoReflect.Descriptor instead. func (*CreateProxyPeerRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{15} + return file_proxy_service_proto_rawDescGZIP(), []int{16} } func (x *CreateProxyPeerRequest) GetServiceId() string { @@ -1340,7 +1441,7 @@ type CreateProxyPeerResponse struct { func (x *CreateProxyPeerResponse) Reset() { *x = CreateProxyPeerResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[16] + mi := &file_proxy_service_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1353,7 +1454,7 @@ func (x *CreateProxyPeerResponse) String() string { func (*CreateProxyPeerResponse) ProtoMessage() {} func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[16] + mi := &file_proxy_service_proto_msgTypes[17] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1366,7 +1467,7 @@ func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerResponse.ProtoReflect.Descriptor instead. func (*CreateProxyPeerResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{16} + return file_proxy_service_proto_rawDescGZIP(), []int{17} } func (x *CreateProxyPeerResponse) GetSuccess() bool { @@ -1396,7 +1497,7 @@ type GetOIDCURLRequest struct { func (x *GetOIDCURLRequest) Reset() { *x = GetOIDCURLRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[17] + mi := &file_proxy_service_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1409,7 +1510,7 @@ func (x *GetOIDCURLRequest) String() string { func (*GetOIDCURLRequest) ProtoMessage() {} func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[17] + mi := &file_proxy_service_proto_msgTypes[18] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1422,7 +1523,7 @@ func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLRequest.ProtoReflect.Descriptor instead. func (*GetOIDCURLRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{17} + return file_proxy_service_proto_rawDescGZIP(), []int{18} } func (x *GetOIDCURLRequest) GetId() string { @@ -1457,7 +1558,7 @@ type GetOIDCURLResponse struct { func (x *GetOIDCURLResponse) Reset() { *x = GetOIDCURLResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1470,7 +1571,7 @@ func (x *GetOIDCURLResponse) String() string { func (*GetOIDCURLResponse) ProtoMessage() {} func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1483,7 +1584,7 @@ func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLResponse.ProtoReflect.Descriptor instead. func (*GetOIDCURLResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{18} + return file_proxy_service_proto_rawDescGZIP(), []int{19} } func (x *GetOIDCURLResponse) GetUrl() string { @@ -1505,7 +1606,7 @@ type ValidateSessionRequest struct { func (x *ValidateSessionRequest) Reset() { *x = ValidateSessionRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1518,7 +1619,7 @@ func (x *ValidateSessionRequest) String() string { func (*ValidateSessionRequest) ProtoMessage() {} func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1531,7 +1632,7 @@ func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionRequest.ProtoReflect.Descriptor instead. func (*ValidateSessionRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{19} + return file_proxy_service_proto_rawDescGZIP(), []int{20} } func (x *ValidateSessionRequest) GetDomain() string { @@ -1562,7 +1663,7 @@ type ValidateSessionResponse struct { func (x *ValidateSessionResponse) Reset() { *x = ValidateSessionResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[20] + mi := &file_proxy_service_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1575,7 +1676,7 @@ func (x *ValidateSessionResponse) String() string { func (*ValidateSessionResponse) ProtoMessage() {} func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[20] + mi := &file_proxy_service_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1588,7 +1689,7 @@ func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionResponse.ProtoReflect.Descriptor instead. func (*ValidateSessionResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{20} + return file_proxy_service_proto_rawDescGZIP(), []int{21} } func (x *ValidateSessionResponse) GetValid() bool { @@ -1628,124 +1729,147 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x22, 0xa3, 0x01, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, - 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, - 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, - 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, - 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, - 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, - 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x52, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, - 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, - 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, - 0x6c, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0xda, 0x02, - 0x0a, 0x11, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, - 0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, - 0x69, 0x70, 0x54, 0x6c, 0x73, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, 0x0a, 0x0f, 0x72, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x0e, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, - 0x3e, 0x0a, 0x0c, 0x70, 0x61, 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, - 0x64, 0x65, 0x52, 0x0b, 0x70, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x12, - 0x57, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, - 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, - 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, - 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, - 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x72, 0x0a, 0x0b, 0x50, 0x61, - 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, - 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, - 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, - 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0xaa, - 0x01, 0x0a, 0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, - 0x65, 0x79, 0x12, 0x35, 0x0a, 0x17, 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, - 0x6e, 0x5f, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x14, 0x6d, 0x61, 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, - 0x67, 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, - 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, - 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x22, 0xe0, 0x02, 0x0a, 0x0c, - 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x36, 0x0a, 0x04, - 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, - 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, - 0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, - 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, - 0x74, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2b, 0x0a, 0x04, 0x70, - 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, - 0x6e, 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, - 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x75, - 0x74, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, 0x75, 0x74, 0x68, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x61, 0x73, 0x73, 0x5f, - 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x08, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, 0x65, 0x61, 0x64, 0x65, - 0x72, 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x5f, 0x72, 0x65, 0x64, - 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x65, - 0x77, 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x22, 0x3f, - 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, 0x67, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, 0x6c, 0x6f, 0x67, 0x22, - 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xea, 0x03, 0x0a, 0x09, 0x41, 0x63, 0x63, - 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, - 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, - 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, - 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, - 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x0a, - 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x73, 0x12, 0x16, - 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, - 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x72, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, 0x61, 0x75, 0x74, 0x68, - 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x12, - 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x75, 0x74, 0x68, - 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, - 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x62, - 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x25, - 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, - 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, 0x73, 0x44, 0x6f, 0x77, - 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0xb6, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, + 0x74, 0x6f, 0x22, 0x66, 0x0a, 0x11, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, + 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x12, 0x37, 0x0a, 0x15, 0x73, 0x75, 0x70, 0x70, 0x6f, + 0x72, 0x74, 0x73, 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x13, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, + 0x74, 0x73, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x6f, 0x72, 0x74, 0x73, 0x88, 0x01, 0x01, + 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x5f, 0x63, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x22, 0xe6, 0x01, 0x0a, 0x17, 0x47, + 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, + 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, + 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, + 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, + 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, + 0x69, 0x65, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, + 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, 0x70, + 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f, + 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, 0x63, + 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0xce, 0x03, 0x0a, 0x11, 0x50, 0x61, 0x74, + 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x26, + 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x54, 0x6c, 0x73, + 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, 0x0a, 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0e, 0x72, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x3e, 0x0a, 0x0c, 0x70, 0x61, + 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, + 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x0b, 0x70, + 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x12, 0x57, 0x0a, 0x0e, 0x63, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, + 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, + 0x65, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x70, 0x72, 0x6f, + 0x78, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x4b, 0x0a, 0x14, 0x73, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x6c, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, + 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x12, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x6c, 0x65, + 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, + 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, + 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x72, 0x0a, 0x0b, 0x50, 0x61, 0x74, + 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, + 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x61, + 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0xaa, 0x01, + 0x0a, 0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, + 0x79, 0x12, 0x35, 0x0a, 0x17, 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, + 0x5f, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x14, 0x6d, 0x61, 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, 0x67, + 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, + 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, + 0x77, 0x6f, 0x72, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x22, 0x95, 0x03, 0x0a, 0x0c, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x36, 0x0a, 0x04, 0x74, + 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, + 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, + 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2b, 0x0a, 0x04, 0x70, 0x61, + 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, + 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x75, 0x74, + 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, 0x75, 0x74, 0x68, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x61, 0x73, 0x73, 0x5f, 0x68, + 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, + 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x5f, 0x72, 0x65, 0x64, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x65, 0x77, + 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x12, 0x12, 0x0a, + 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6d, 0x6f, 0x64, + 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, + 0x72, 0x74, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, + 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, + 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x86, 0x04, 0x0a, + 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, + 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, + 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, + 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, + 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x4d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, + 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, + 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, + 0x69, 0x73, 0x6d, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, + 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, + 0x21, 0x0a, 0x0c, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, + 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, + 0x61, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, + 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, + 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0xb6, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, @@ -1907,70 +2031,73 @@ func file_proxy_service_proto_rawDescGZIP() []byte { } var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 22) +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 23) var file_proxy_service_proto_goTypes = []interface{}{ (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType (PathRewriteMode)(0), // 1: management.PathRewriteMode (ProxyStatus)(0), // 2: management.ProxyStatus - (*GetMappingUpdateRequest)(nil), // 3: management.GetMappingUpdateRequest - (*GetMappingUpdateResponse)(nil), // 4: management.GetMappingUpdateResponse - (*PathTargetOptions)(nil), // 5: management.PathTargetOptions - (*PathMapping)(nil), // 6: management.PathMapping - (*Authentication)(nil), // 7: management.Authentication - (*ProxyMapping)(nil), // 8: management.ProxyMapping - (*SendAccessLogRequest)(nil), // 9: management.SendAccessLogRequest - (*SendAccessLogResponse)(nil), // 10: management.SendAccessLogResponse - (*AccessLog)(nil), // 11: management.AccessLog - (*AuthenticateRequest)(nil), // 12: management.AuthenticateRequest - (*PasswordRequest)(nil), // 13: management.PasswordRequest - (*PinRequest)(nil), // 14: management.PinRequest - (*AuthenticateResponse)(nil), // 15: management.AuthenticateResponse - (*SendStatusUpdateRequest)(nil), // 16: management.SendStatusUpdateRequest - (*SendStatusUpdateResponse)(nil), // 17: management.SendStatusUpdateResponse - (*CreateProxyPeerRequest)(nil), // 18: management.CreateProxyPeerRequest - (*CreateProxyPeerResponse)(nil), // 19: management.CreateProxyPeerResponse - (*GetOIDCURLRequest)(nil), // 20: management.GetOIDCURLRequest - (*GetOIDCURLResponse)(nil), // 21: management.GetOIDCURLResponse - (*ValidateSessionRequest)(nil), // 22: management.ValidateSessionRequest - (*ValidateSessionResponse)(nil), // 23: management.ValidateSessionResponse - nil, // 24: management.PathTargetOptions.CustomHeadersEntry - (*timestamppb.Timestamp)(nil), // 25: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 26: google.protobuf.Duration + (*ProxyCapabilities)(nil), // 3: management.ProxyCapabilities + (*GetMappingUpdateRequest)(nil), // 4: management.GetMappingUpdateRequest + (*GetMappingUpdateResponse)(nil), // 5: management.GetMappingUpdateResponse + (*PathTargetOptions)(nil), // 6: management.PathTargetOptions + (*PathMapping)(nil), // 7: management.PathMapping + (*Authentication)(nil), // 8: management.Authentication + (*ProxyMapping)(nil), // 9: management.ProxyMapping + (*SendAccessLogRequest)(nil), // 10: management.SendAccessLogRequest + (*SendAccessLogResponse)(nil), // 11: management.SendAccessLogResponse + (*AccessLog)(nil), // 12: management.AccessLog + (*AuthenticateRequest)(nil), // 13: management.AuthenticateRequest + (*PasswordRequest)(nil), // 14: management.PasswordRequest + (*PinRequest)(nil), // 15: management.PinRequest + (*AuthenticateResponse)(nil), // 16: management.AuthenticateResponse + (*SendStatusUpdateRequest)(nil), // 17: management.SendStatusUpdateRequest + (*SendStatusUpdateResponse)(nil), // 18: management.SendStatusUpdateResponse + (*CreateProxyPeerRequest)(nil), // 19: management.CreateProxyPeerRequest + (*CreateProxyPeerResponse)(nil), // 20: management.CreateProxyPeerResponse + (*GetOIDCURLRequest)(nil), // 21: management.GetOIDCURLRequest + (*GetOIDCURLResponse)(nil), // 22: management.GetOIDCURLResponse + (*ValidateSessionRequest)(nil), // 23: management.ValidateSessionRequest + (*ValidateSessionResponse)(nil), // 24: management.ValidateSessionResponse + nil, // 25: management.PathTargetOptions.CustomHeadersEntry + (*timestamppb.Timestamp)(nil), // 26: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 27: google.protobuf.Duration } var file_proxy_service_proto_depIdxs = []int32{ - 25, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp - 8, // 1: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping - 26, // 2: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration - 1, // 3: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode - 24, // 4: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry - 5, // 5: management.PathMapping.options:type_name -> management.PathTargetOptions - 0, // 6: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType - 6, // 7: management.ProxyMapping.path:type_name -> management.PathMapping - 7, // 8: management.ProxyMapping.auth:type_name -> management.Authentication - 11, // 9: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 25, // 10: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp - 13, // 11: management.AuthenticateRequest.password:type_name -> management.PasswordRequest - 14, // 12: management.AuthenticateRequest.pin:type_name -> management.PinRequest - 2, // 13: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus - 3, // 14: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest - 9, // 15: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest - 12, // 16: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest - 16, // 17: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest - 18, // 18: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest - 20, // 19: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest - 22, // 20: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest - 4, // 21: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 10, // 22: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 15, // 23: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 17, // 24: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 19, // 25: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 21, // 26: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 23, // 27: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse - 21, // [21:28] is the sub-list for method output_type - 14, // [14:21] is the sub-list for method input_type - 14, // [14:14] is the sub-list for extension type_name - 14, // [14:14] is the sub-list for extension extendee - 0, // [0:14] is the sub-list for field type_name + 26, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 3, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities + 9, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping + 27, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration + 1, // 4: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode + 25, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry + 27, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration + 6, // 7: management.PathMapping.options:type_name -> management.PathTargetOptions + 0, // 8: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType + 7, // 9: management.ProxyMapping.path:type_name -> management.PathMapping + 8, // 10: management.ProxyMapping.auth:type_name -> management.Authentication + 12, // 11: management.SendAccessLogRequest.log:type_name -> management.AccessLog + 26, // 12: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 14, // 13: management.AuthenticateRequest.password:type_name -> management.PasswordRequest + 15, // 14: management.AuthenticateRequest.pin:type_name -> management.PinRequest + 2, // 15: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus + 4, // 16: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest + 10, // 17: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest + 13, // 18: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest + 17, // 19: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest + 19, // 20: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest + 21, // 21: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest + 23, // 22: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 5, // 23: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 11, // 24: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 16, // 25: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 18, // 26: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 20, // 27: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 22, // 28: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 24, // 29: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 23, // [23:30] is the sub-list for method output_type + 16, // [16:23] is the sub-list for method input_type + 16, // [16:16] is the sub-list for extension type_name + 16, // [16:16] is the sub-list for extension extendee + 0, // [0:16] is the sub-list for field type_name } func init() { file_proxy_service_proto_init() } @@ -1980,7 +2107,7 @@ func file_proxy_service_proto_init() { } if !protoimpl.UnsafeEnabled { file_proxy_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetMappingUpdateRequest); i { + switch v := v.(*ProxyCapabilities); i { case 0: return &v.state case 1: @@ -1992,7 +2119,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetMappingUpdateResponse); i { + switch v := v.(*GetMappingUpdateRequest); i { case 0: return &v.state case 1: @@ -2004,7 +2131,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PathTargetOptions); i { + switch v := v.(*GetMappingUpdateResponse); i { case 0: return &v.state case 1: @@ -2016,7 +2143,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PathMapping); i { + switch v := v.(*PathTargetOptions); i { case 0: return &v.state case 1: @@ -2028,7 +2155,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Authentication); i { + switch v := v.(*PathMapping); i { case 0: return &v.state case 1: @@ -2040,7 +2167,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProxyMapping); i { + switch v := v.(*Authentication); i { case 0: return &v.state case 1: @@ -2052,7 +2179,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogRequest); i { + switch v := v.(*ProxyMapping); i { case 0: return &v.state case 1: @@ -2064,7 +2191,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogResponse); i { + switch v := v.(*SendAccessLogRequest); i { case 0: return &v.state case 1: @@ -2076,7 +2203,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AccessLog); i { + switch v := v.(*SendAccessLogResponse); i { case 0: return &v.state case 1: @@ -2088,7 +2215,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateRequest); i { + switch v := v.(*AccessLog); i { case 0: return &v.state case 1: @@ -2100,7 +2227,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PasswordRequest); i { + switch v := v.(*AuthenticateRequest); i { case 0: return &v.state case 1: @@ -2112,7 +2239,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PinRequest); i { + switch v := v.(*PasswordRequest); i { case 0: return &v.state case 1: @@ -2124,7 +2251,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateResponse); i { + switch v := v.(*PinRequest); i { case 0: return &v.state case 1: @@ -2136,7 +2263,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateRequest); i { + switch v := v.(*AuthenticateResponse); i { case 0: return &v.state case 1: @@ -2148,7 +2275,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateResponse); i { + switch v := v.(*SendStatusUpdateRequest); i { case 0: return &v.state case 1: @@ -2160,7 +2287,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerRequest); i { + switch v := v.(*SendStatusUpdateResponse); i { case 0: return &v.state case 1: @@ -2172,7 +2299,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerResponse); i { + switch v := v.(*CreateProxyPeerRequest); i { case 0: return &v.state case 1: @@ -2184,7 +2311,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLRequest); i { + switch v := v.(*CreateProxyPeerResponse); i { case 0: return &v.state case 1: @@ -2196,7 +2323,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLResponse); i { + switch v := v.(*GetOIDCURLRequest); i { case 0: return &v.state case 1: @@ -2208,7 +2335,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateSessionRequest); i { + switch v := v.(*GetOIDCURLResponse); i { case 0: return &v.state case 1: @@ -2220,6 +2347,18 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ValidateSessionRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ValidateSessionResponse); i { case 0: return &v.state @@ -2232,19 +2371,20 @@ func file_proxy_service_proto_init() { } } } - file_proxy_service_proto_msgTypes[9].OneofWrappers = []interface{}{ + file_proxy_service_proto_msgTypes[0].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[10].OneofWrappers = []interface{}{ (*AuthenticateRequest_Password)(nil), (*AuthenticateRequest_Pin)(nil), } - file_proxy_service_proto_msgTypes[13].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[16].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[14].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[17].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proxy_service_proto_rawDesc, NumEnums: 3, - NumMessages: 22, + NumMessages: 23, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index 195b60f01..457d12e85 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -27,12 +27,19 @@ service ProxyService { rpc ValidateSession(ValidateSessionRequest) returns (ValidateSessionResponse); } +// ProxyCapabilities describes what a proxy can handle. +message ProxyCapabilities { + // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. + optional bool supports_custom_ports = 1; +} + // GetMappingUpdateRequest is sent to initialise a mapping stream. message GetMappingUpdateRequest { string proxy_id = 1; string version = 2; google.protobuf.Timestamp started_at = 3; string address = 4; + ProxyCapabilities capabilities = 5; } // GetMappingUpdateResponse contains zero or more ProxyMappings. @@ -61,6 +68,10 @@ message PathTargetOptions { google.protobuf.Duration request_timeout = 2; PathRewriteMode path_rewrite = 3; map custom_headers = 4; + // Send PROXY protocol v2 header to this backend. + bool proxy_protocol = 5; + // Idle timeout before a UDP session is reaped. + google.protobuf.Duration session_idle_timeout = 6; } message PathMapping { @@ -91,6 +102,10 @@ message ProxyMapping { // When true, Location headers in backend responses are rewritten to replace // the backend address with the public-facing domain. bool rewrite_redirects = 9; + // Service mode: "http", "tcp", "udp", or "tls". + string mode = 10; + // For L4/TLS: the port the proxy listens on. + int32 listen_port = 11; } // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. @@ -117,6 +132,7 @@ message AccessLog { bool auth_success = 13; int64 bytes_upload = 14; int64 bytes_download = 15; + string protocol = 16; } message AuthenticateRequest {
Account IDDomainsServices Age Status
{{.AccountID}}{{.Domains}}{{.Services}} {{.Age}} {{.Status}}