mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-23 17:19:54 +00:00
Compare commits
12 Commits
windows-dn
...
add-json-y
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fd862e157 | ||
|
|
75214223d7 | ||
|
|
fd0834441d | ||
|
|
d3293fb282 | ||
|
|
a212963dac | ||
|
|
af24fd7796 | ||
|
|
13d32d274f | ||
|
|
705f87fc20 | ||
|
|
3f91f49277 | ||
|
|
347c5bf317 | ||
|
|
22e2519d71 | ||
|
|
e916f12cca |
@@ -92,9 +92,6 @@ linters:
|
||||
- linters:
|
||||
- unused
|
||||
path: client/firewall/iptables/rule\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: client/internal/dns/dnsfw/(types|syscall|zsyscall)_windows.*\.go
|
||||
- linters:
|
||||
- gosec
|
||||
- mirror
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
@@ -108,16 +109,41 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
|
||||
if resp.GetUploadFailureReason() != "" {
|
||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||
}
|
||||
|
||||
if uploadBundleFlag {
|
||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||
return emitDebugBundle(cmd, resp.GetPath(), resp.GetUploadedKey())
|
||||
}
|
||||
|
||||
func emitDebugBundle(cmd *cobra.Command, path, uploadedKey string) error {
|
||||
if !jsonFlag && !yamlFlag {
|
||||
cmd.Printf("Local file:\n%s\n", path)
|
||||
if uploadBundleFlag {
|
||||
cmd.Printf("Upload file key:\n%s\n", uploadedKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
out := &nbstatus.DebugBundleOutput{Path: path}
|
||||
if uploadBundleFlag {
|
||||
out.UploadedKey = uploadedKey
|
||||
}
|
||||
|
||||
if jsonFlag {
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
return nil
|
||||
}
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -451,6 +477,9 @@ func init() {
|
||||
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
||||
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
||||
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
debugBundleCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
|
||||
debugBundleCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
|
||||
debugBundleCmd.MarkFlagsMutuallyExclusive("json", "yaml")
|
||||
|
||||
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
|
||||
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
)
|
||||
|
||||
var downCmd = &cobra.Command{
|
||||
@@ -44,7 +45,29 @@ var downCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Disconnected")
|
||||
out := &nbstatus.DownOutput{Status: "Disconnected"}
|
||||
switch {
|
||||
case jsonFlag:
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
case yamlFlag:
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
default:
|
||||
cmd.Println(out.Status)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
downCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
|
||||
downCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
|
||||
downCmd.MarkFlagsMutuallyExclusive("json", "yaml")
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
)
|
||||
|
||||
var forwardingRulesCmd = &cobra.Command{
|
||||
@@ -25,6 +26,12 @@ var forwardingRulesListCmd = &cobra.Command{
|
||||
RunE: listForwardingRules,
|
||||
}
|
||||
|
||||
func init() {
|
||||
forwardingRulesListCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
|
||||
forwardingRulesListCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
|
||||
forwardingRulesListCmd.MarkFlagsMutuallyExclusive("json", "yaml")
|
||||
}
|
||||
|
||||
func listForwardingRules(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
@@ -38,19 +45,23 @@ func listForwardingRules(cmd *cobra.Command, _ []string) error {
|
||||
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if len(resp.GetRules()) == 0 {
|
||||
rules := resp.GetRules()
|
||||
sortForwardingRules(rules)
|
||||
|
||||
if jsonFlag || yamlFlag {
|
||||
return emitForwardingList(cmd, rules)
|
||||
}
|
||||
|
||||
if len(rules) == 0 {
|
||||
cmd.Println("No forwarding rules available.")
|
||||
return nil
|
||||
}
|
||||
|
||||
printForwardingRules(cmd, resp.GetRules())
|
||||
printForwardingRules(cmd, rules)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
|
||||
cmd.Println("Available forwarding rules:")
|
||||
|
||||
// Sort rules by translated address
|
||||
func sortForwardingRules(rules []*proto.ForwardingRule) {
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
|
||||
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
|
||||
@@ -58,9 +69,45 @@ func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
|
||||
if rules[i].GetProtocol() != rules[j].GetProtocol() {
|
||||
return rules[i].GetProtocol() < rules[j].GetProtocol()
|
||||
}
|
||||
|
||||
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
|
||||
})
|
||||
}
|
||||
|
||||
func emitForwardingList(cmd *cobra.Command, rules []*proto.ForwardingRule) error {
|
||||
out := &nbstatus.ForwardingListOutput{Rules: make([]nbstatus.ForwardingRuleOutput, 0, len(rules))}
|
||||
for _, rule := range rules {
|
||||
row := nbstatus.ForwardingRuleOutput{
|
||||
TranslatedAddress: rule.GetTranslatedAddress(),
|
||||
TranslatedHostname: rule.GetTranslatedHostname(),
|
||||
Protocol: rule.GetProtocol(),
|
||||
}
|
||||
if s, ok := portToStringOpt(rule.GetDestinationPort()); ok {
|
||||
row.DestinationPort = &s
|
||||
}
|
||||
if s, ok := portToStringOpt(rule.GetTranslatedPort()); ok {
|
||||
row.TranslatedPort = &s
|
||||
}
|
||||
out.Rules = append(out.Rules, row)
|
||||
}
|
||||
|
||||
if jsonFlag {
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
return nil
|
||||
}
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
|
||||
cmd.Println("Available forwarding rules:")
|
||||
|
||||
var lastIP string
|
||||
for _, rule := range rules {
|
||||
@@ -87,12 +134,22 @@ func getFirstPort(portInfo *proto.PortInfo) int {
|
||||
}
|
||||
|
||||
func portToString(translatedPort *proto.PortInfo) string {
|
||||
switch v := translatedPort.PortSelection.(type) {
|
||||
if s, ok := portToStringOpt(translatedPort); ok {
|
||||
return s
|
||||
}
|
||||
return "No port specified"
|
||||
}
|
||||
|
||||
// portToStringOpt returns the formatted port string and whether port info was
|
||||
// actually present. Used by the structured (json/yaml) output so the absent
|
||||
// case becomes a missing field instead of a sentinel string.
|
||||
func portToStringOpt(p *proto.PortInfo) (string, bool) {
|
||||
switch v := p.GetPortSelection().(type) {
|
||||
case *proto.PortInfo_Port:
|
||||
return fmt.Sprintf("%d", v.Port)
|
||||
return fmt.Sprintf("%d", v.Port), true
|
||||
case *proto.PortInfo_Range_:
|
||||
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
|
||||
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd()), true
|
||||
default:
|
||||
return "No port specified"
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -27,6 +28,10 @@ func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||
|
||||
loginCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
|
||||
loginCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
|
||||
loginCmd.MarkFlagsMutuallyExclusive("json", "yaml")
|
||||
}
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
@@ -73,12 +78,36 @@ var loginCmd = &cobra.Command{
|
||||
return fmt.Errorf("daemon login failed: %v", err)
|
||||
}
|
||||
|
||||
cmd.Println("Logging successfully")
|
||||
|
||||
return nil
|
||||
return emitLoginOutput(cmd, activeProf.Name)
|
||||
},
|
||||
}
|
||||
|
||||
// emitLoginOutput writes the result of a successful login in the format
|
||||
// requested by the user (json, yaml, or human-readable text fallback).
|
||||
func emitLoginOutput(cmd *cobra.Command, profileName string) error {
|
||||
out := &nbstatus.LoginOutput{
|
||||
Status: "logged_in",
|
||||
ProfileName: profileName,
|
||||
}
|
||||
switch {
|
||||
case jsonFlag:
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
case yamlFlag:
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
default:
|
||||
cmd.Println("Logging successfully")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
@@ -253,8 +282,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
cmd.Println("Logging successfully")
|
||||
return nil
|
||||
return emitLoginOutput(cmd, activeProf.Name)
|
||||
}
|
||||
|
||||
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
|
||||
@@ -337,6 +365,11 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
}
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
|
||||
if jsonFlag || yamlFlag {
|
||||
emitSSOEvent(cmd, verificationURIComplete, userCode)
|
||||
return
|
||||
}
|
||||
|
||||
var codeMsg string
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
@@ -366,6 +399,33 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
}
|
||||
}
|
||||
|
||||
// emitSSOEvent writes the verification URL/code as a structured event for
|
||||
// callers using --json or --yaml. The browser is intentionally not opened in
|
||||
// this mode since automation contexts (CI, scripts) typically run headless.
|
||||
func emitSSOEvent(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||
event := &nbstatus.SSOEvent{
|
||||
Event: "sso_required",
|
||||
VerificationURIComplete: verificationURIComplete,
|
||||
UserCode: userCode,
|
||||
}
|
||||
if jsonFlag {
|
||||
s, err := event.JSON()
|
||||
if err != nil {
|
||||
log.Errorf("marshal sso event: %v", err)
|
||||
return
|
||||
}
|
||||
cmd.Println(s)
|
||||
return
|
||||
}
|
||||
s, err := event.YAML()
|
||||
if err != nil {
|
||||
log.Errorf("marshal sso event: %v", err)
|
||||
return
|
||||
}
|
||||
cmd.Print(s)
|
||||
cmd.Println("---")
|
||||
}
|
||||
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
)
|
||||
|
||||
var logoutCmd = &cobra.Command{
|
||||
@@ -49,11 +50,33 @@ var logoutCmd = &cobra.Command{
|
||||
return fmt.Errorf("deregister: %v", err)
|
||||
}
|
||||
|
||||
cmd.Println("Deregistered successfully")
|
||||
out := &nbstatus.DeregisterOutput{
|
||||
Status: "deregistered",
|
||||
ProfileName: profileName,
|
||||
}
|
||||
switch {
|
||||
case jsonFlag:
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
case yamlFlag:
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
default:
|
||||
cmd.Println("Deregistered successfully")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
logoutCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
logoutCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
|
||||
logoutCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
|
||||
logoutCmd.MarkFlagsMutuallyExclusive("json", "yaml")
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
)
|
||||
|
||||
var appendFlag bool
|
||||
@@ -48,6 +49,12 @@ var routesDeselectCmd = &cobra.Command{
|
||||
|
||||
func init() {
|
||||
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing")
|
||||
|
||||
for _, c := range []*cobra.Command{routesListCmd, routesSelectCmd, routesDeselectCmd} {
|
||||
c.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
|
||||
c.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
|
||||
c.MarkFlagsMutuallyExclusive("json", "yaml")
|
||||
}
|
||||
}
|
||||
|
||||
func networksList(cmd *cobra.Command, _ []string) error {
|
||||
@@ -63,6 +70,10 @@ func networksList(cmd *cobra.Command, _ []string) error {
|
||||
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if jsonFlag || yamlFlag {
|
||||
return emitNetworksList(cmd, resp)
|
||||
}
|
||||
|
||||
if len(resp.Routes) == 0 {
|
||||
cmd.Println("No networks available.")
|
||||
return nil
|
||||
@@ -73,6 +84,40 @@ func networksList(cmd *cobra.Command, _ []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func emitNetworksList(cmd *cobra.Command, resp *proto.ListNetworksResponse) error {
|
||||
out := &nbstatus.NetworksListOutput{Networks: make([]nbstatus.NetworkOutput, 0, len(resp.GetRoutes()))}
|
||||
for _, route := range resp.GetRoutes() {
|
||||
row := nbstatus.NetworkOutput{
|
||||
ID: route.GetID(),
|
||||
Range: route.GetRange(),
|
||||
Domains: route.GetDomains(),
|
||||
Selected: route.GetSelected(),
|
||||
}
|
||||
if resolved := route.GetResolvedIPs(); len(resolved) > 0 {
|
||||
row.ResolvedIPs = make(map[string][]string, len(resolved))
|
||||
for d, ipList := range resolved {
|
||||
row.ResolvedIPs[d] = ipList.GetIps()
|
||||
}
|
||||
}
|
||||
out.Networks = append(out.Networks, row)
|
||||
}
|
||||
|
||||
if jsonFlag {
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
return nil
|
||||
}
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
|
||||
cmd.Println("Available Networks:")
|
||||
for _, route := range resp.Routes {
|
||||
@@ -142,11 +187,10 @@ func networksSelect(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Networks selected successfully.")
|
||||
|
||||
return nil
|
||||
return emitNetworksMutation(cmd, "selected", req, "Networks selected successfully.")
|
||||
}
|
||||
|
||||
|
||||
func networksDeselect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
@@ -167,7 +211,35 @@ func networksDeselect(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Networks deselected successfully.")
|
||||
return emitNetworksMutation(cmd, "deselected", req, "Networks deselected successfully.")
|
||||
}
|
||||
|
||||
func emitNetworksMutation(cmd *cobra.Command, action string, req *proto.SelectNetworksRequest, textFallback string) error {
|
||||
if !jsonFlag && !yamlFlag {
|
||||
cmd.Println(textFallback)
|
||||
return nil
|
||||
}
|
||||
|
||||
out := &nbstatus.NetworksMutationOutput{
|
||||
Status: action,
|
||||
All: req.GetAll(),
|
||||
}
|
||||
if !req.GetAll() {
|
||||
out.Networks = req.GetNetworkIDs()
|
||||
}
|
||||
|
||||
if jsonFlag {
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
return nil
|
||||
}
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,9 +11,18 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
for _, c := range []*cobra.Command{profileListCmd, profileAddCmd, profileRemoveCmd, profileSelectCmd} {
|
||||
c.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
|
||||
c.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
|
||||
c.MarkFlagsMutuallyExclusive("json", "yaml")
|
||||
}
|
||||
}
|
||||
|
||||
var profileCmd = &cobra.Command{
|
||||
Use: "profile",
|
||||
Short: "Manage NetBird client profiles",
|
||||
@@ -90,6 +99,10 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if jsonFlag || yamlFlag {
|
||||
return emitProfileList(cmd, profiles.GetProfiles())
|
||||
}
|
||||
|
||||
// list profiles, add a tick if the profile is active
|
||||
cmd.Println("Found", len(profiles.Profiles), "profiles:")
|
||||
for _, profile := range profiles.Profiles {
|
||||
@@ -104,6 +117,58 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func emitProfileMutation(cmd *cobra.Command, action, profile, textFallback string) error {
|
||||
if !jsonFlag && !yamlFlag {
|
||||
cmd.Println(textFallback)
|
||||
return nil
|
||||
}
|
||||
|
||||
out := &nbstatus.ProfileMutationOutput{
|
||||
Status: action,
|
||||
ProfileName: profile,
|
||||
}
|
||||
|
||||
if jsonFlag {
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
return nil
|
||||
}
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func emitProfileList(cmd *cobra.Command, profiles []*proto.Profile) error {
|
||||
out := &nbstatus.ProfileListOutput{Profiles: make([]nbstatus.ProfileOutput, 0, len(profiles))}
|
||||
for _, p := range profiles {
|
||||
out.Profiles = append(out.Profiles, nbstatus.ProfileOutput{
|
||||
Name: p.GetName(),
|
||||
Active: p.GetIsActive(),
|
||||
})
|
||||
}
|
||||
|
||||
if jsonFlag {
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
return nil
|
||||
}
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
@@ -132,8 +197,7 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Profile added successfully:", profileName)
|
||||
return nil
|
||||
return emitProfileMutation(cmd, "added", profileName, "Profile added successfully: "+profileName)
|
||||
}
|
||||
|
||||
func removeProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -164,8 +228,7 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Profile removed successfully:", profileName)
|
||||
return nil
|
||||
return emitProfileMutation(cmd, "removed", profileName, "Profile removed successfully: "+profileName)
|
||||
}
|
||||
|
||||
func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -231,6 +294,5 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("Profile switched successfully to:", profileName)
|
||||
return nil
|
||||
return emitProfileMutation(cmd, "selected", profileName, "Profile switched successfully to: "+profileName)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -75,6 +76,10 @@ func init() {
|
||||
|
||||
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
|
||||
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
|
||||
|
||||
stateListCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
|
||||
stateListCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
|
||||
stateListCmd.MarkFlagsMutuallyExclusive("json", "yaml")
|
||||
}
|
||||
|
||||
func stateList(cmd *cobra.Command, _ []string) error {
|
||||
@@ -94,6 +99,10 @@ func stateList(cmd *cobra.Command, _ []string) error {
|
||||
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if jsonFlag || yamlFlag {
|
||||
return emitStateList(cmd, resp.GetStates())
|
||||
}
|
||||
|
||||
cmd.Printf("\nStored states:\n\n")
|
||||
for _, state := range resp.States {
|
||||
cmd.Printf("- %s\n", state.Name)
|
||||
@@ -102,6 +111,28 @@ func stateList(cmd *cobra.Command, _ []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func emitStateList(cmd *cobra.Command, states []*proto.State) error {
|
||||
out := &nbstatus.StateListOutput{States: make([]nbstatus.StateOutput, 0, len(states))}
|
||||
for _, s := range states {
|
||||
out.States = append(out.States, nbstatus.StateOutput{Name: s.GetName()})
|
||||
}
|
||||
|
||||
if jsonFlag {
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
return nil
|
||||
}
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func stateClean(cmd *cobra.Command, args []string) error {
|
||||
var stateName string
|
||||
if !allFlag {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -88,6 +89,9 @@ func init() {
|
||||
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
||||
|
||||
upCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
|
||||
upCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
|
||||
upCmd.MarkFlagsMutuallyExclusive("json", "yaml")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -96,6 +100,10 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
if (jsonFlag || yamlFlag) && foregroundMode {
|
||||
return fmt.Errorf("--json/--yaml is not supported with --foreground-mode; use daemon mode")
|
||||
}
|
||||
|
||||
err := util.InitLog(logLevel, util.LogConsole)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
@@ -245,8 +253,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
if !profileSwitched {
|
||||
cmd.Println("Already connected")
|
||||
return nil
|
||||
return emitUpOutput(cmd, &nbstatus.UpOutput{
|
||||
Status: "already_connected",
|
||||
ProfileName: activeProf.Name,
|
||||
}, "Already connected")
|
||||
}
|
||||
|
||||
if _, err := client.Down(ctx, &proto.DownRequest{}); err != nil {
|
||||
@@ -273,7 +283,31 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
if err := doDaemonUp(ctx, cmd, client, pm, activeProf, customDNSAddressConverted, username.Username); err != nil {
|
||||
return fmt.Errorf("daemon up failed: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return emitUpOutput(cmd, &nbstatus.UpOutput{
|
||||
Status: "connected",
|
||||
ProfileName: activeProf.Name,
|
||||
}, "Connected")
|
||||
}
|
||||
|
||||
// emitUpOutput writes the result of an up command in the format requested by
|
||||
// the user (json, yaml, or human-readable text fallback).
|
||||
func emitUpOutput(cmd *cobra.Command, out *nbstatus.UpOutput, textFallback string) error {
|
||||
switch {
|
||||
case jsonFlag:
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
case yamlFlag:
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
default:
|
||||
cmd.Println(textFallback)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -10,9 +11,35 @@ var (
|
||||
versionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print the NetBird's client application version",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
cmd.Println(version.NetbirdVersion())
|
||||
v := version.NetbirdVersion()
|
||||
|
||||
switch {
|
||||
case jsonFlag:
|
||||
out := &nbstatus.VersionOutput{Version: v}
|
||||
s, err := out.JSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(s)
|
||||
case yamlFlag:
|
||||
out := &nbstatus.VersionOutput{Version: v}
|
||||
s, err := out.YAML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(s)
|
||||
default:
|
||||
cmd.Println(v)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
versionCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
|
||||
versionCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
|
||||
versionCmd.MarkFlagsMutuallyExclusive("json", "yaml")
|
||||
}
|
||||
|
||||
@@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
||||
|
||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
|
||||
; Create autostart registry entry based on checkbox
|
||||
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
|
||||
; or HKCU by legacy installers.
|
||||
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
|
||||
@@ -299,11 +307,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
DetailPrint "Terminating Netbird UI process..."
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart registry entry
|
||||
; Remove autostart entries from every view a previous installer may have used.
|
||||
DetailPrint "Removing autostart registry entry if exists..."
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
; Handle data deletion based on checkbox
|
||||
DetailPrint "Checking if user requested data deletion..."
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisable disables the DNS firewall entirely when set to a truthy value.
|
||||
EnvDisable = "NB_DISABLE_DNS_FIREWALL"
|
||||
// EnvPorts overrides the comma-separated list of remote ports to block.
|
||||
// Empty disables the firewall.
|
||||
EnvPorts = "NB_DNS_FIREWALL_PORTS"
|
||||
// EnvStrict enables strict mode: permit DNS only to the virtual DNS IP
|
||||
// and the netbird daemon. Default mode also permits anything on the
|
||||
// netbird tunnel interface, which is safer if NRPT is silently ignored
|
||||
// by Windows but lets apps reach custom DNS servers via the tunnel.
|
||||
EnvStrict = "NB_DNS_FIREWALL_STRICT"
|
||||
)
|
||||
|
||||
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
|
||||
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
|
||||
var defaultBlockedPorts = []uint16{53, 853}
|
||||
|
||||
// blockedPorts returns the effective port list, honoring env overrides.
|
||||
// A nil return means the firewall should not be installed.
|
||||
func blockedPorts() []uint16 {
|
||||
if disabled, _ := strconv.ParseBool(os.Getenv(EnvDisable)); disabled {
|
||||
log.Infof("dns firewall disabled via %s", EnvDisable)
|
||||
return nil
|
||||
}
|
||||
|
||||
override, ok := os.LookupEnv(EnvPorts)
|
||||
if !ok {
|
||||
return defaultBlockedPorts
|
||||
}
|
||||
|
||||
var ports []uint16
|
||||
for _, raw := range strings.Split(override, ",") {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
port, err := strconv.ParseUint(raw, 10, 16)
|
||||
if err != nil {
|
||||
log.Warnf("dns firewall: ignoring invalid port %q in %s: %v", raw, EnvPorts, err)
|
||||
continue
|
||||
}
|
||||
if port == 0 {
|
||||
log.Warnf("dns firewall: ignoring port 0 in %s", EnvPorts)
|
||||
continue
|
||||
}
|
||||
ports = append(ports, uint16(port))
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
log.Infof("dns firewall disabled: %s yielded no valid ports", EnvPorts)
|
||||
return nil
|
||||
}
|
||||
return ports
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBlockedPorts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
disable string
|
||||
ports string
|
||||
setPorts bool
|
||||
want []uint16
|
||||
}{
|
||||
{name: "default", want: defaultBlockedPorts},
|
||||
{name: "disabled", disable: "true", want: nil},
|
||||
{name: "disabled false keeps default", disable: "false", want: defaultBlockedPorts},
|
||||
{name: "override single port", ports: "53", setPorts: true, want: []uint16{53}},
|
||||
{name: "override multi", ports: "53, 853 ,5353", setPorts: true, want: []uint16{53, 853, 5353}},
|
||||
{name: "override empty disables", ports: "", setPorts: true, want: nil},
|
||||
{name: "override invalid skipped", ports: "53,not-a-port,853", setPorts: true, want: []uint16{53, 853}},
|
||||
{name: "override zero skipped", ports: "53,0,853", setPorts: true, want: []uint16{53, 853}},
|
||||
{name: "override only invalid disables", ports: "abc", setPorts: true, want: nil},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvDisable, tc.disable)
|
||||
if tc.setPorts {
|
||||
t.Setenv(EnvPorts, tc.ports)
|
||||
}
|
||||
got := blockedPorts()
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("blockedPorts() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
// Package dnsfw blocks DNS traffic from non-netbird processes when netbird is
|
||||
// managing the host's DNS, so that resolvers running on apps or libraries
|
||||
// outside netbird cannot bypass the configured DNS path.
|
||||
//
|
||||
// Implementation is Windows-only (uses WFP). On other platforms New returns
|
||||
// a no-op manager.
|
||||
package dnsfw
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// Manager controls the per-tunnel DNS firewall. Both methods must be safe
|
||||
// to call multiple times.
|
||||
type Manager interface {
|
||||
Enable(ifaceGUID string, virtualDNSIP netip.Addr) error
|
||||
Disable() error
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "net/netip"
|
||||
|
||||
type noopManager struct{}
|
||||
|
||||
func (noopManager) Enable(string, netip.Addr) error { return nil }
|
||||
func (noopManager) Disable() error { return nil }
|
||||
|
||||
// New returns a no-op manager on non-Windows platforms.
|
||||
func New() Manager {
|
||||
return noopManager{}
|
||||
}
|
||||
@@ -1,144 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
|
||||
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
|
||||
)
|
||||
|
||||
type windowsManager struct {
|
||||
mu sync.Mutex
|
||||
// session is the WFP engine handle. Zero when disabled.
|
||||
session uintptr
|
||||
}
|
||||
|
||||
// Enable installs the dns firewall. Strict mode propagates failures;
|
||||
// non-strict mode logs and returns nil so partial protection is preserved.
|
||||
func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
ports := blockedPorts()
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.session != 0 {
|
||||
if err := m.disableLocked(); err != nil {
|
||||
return fmt.Errorf("reset existing dns firewall session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
strict := strictMode()
|
||||
|
||||
luid, err := luidFromGUID(ifaceGUID)
|
||||
if err != nil {
|
||||
return m.failOrLog(strict, fmt.Errorf("resolve tun luid from guid %s: %w", ifaceGUID, err))
|
||||
}
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return m.failOrLog(strict, fmt.Errorf("resolve daemon executable path: %w", err))
|
||||
}
|
||||
|
||||
cfg := installConfig{
|
||||
tunLUID: luid,
|
||||
daemonExe: exe,
|
||||
blockedPorts: ports,
|
||||
strict: strict,
|
||||
virtualDNSIP: virtualDNSIP,
|
||||
}
|
||||
// session==0 signals a hard failure; non-zero with non-nil err is a partial install.
|
||||
session, installErr := installFilters(cfg)
|
||||
if session == 0 {
|
||||
return m.failOrLog(strict, fmt.Errorf("install dns firewall filters: %w", installErr))
|
||||
}
|
||||
|
||||
if installErr != nil && strict {
|
||||
_ = closeSession(session)
|
||||
return fmt.Errorf("strict dns firewall: partial install: %w", installErr)
|
||||
}
|
||||
|
||||
m.session = session
|
||||
log.Infof("dns firewall installed: iface=%s daemon=%s ports=%v strict=%v virtual_dns=%s",
|
||||
ifaceGUID, exe, ports, strict, virtualDNSIP)
|
||||
if installErr != nil {
|
||||
log.Warnf("dns firewall partially installed (some filters failed): %v", installErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *windowsManager) Disable() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.disableLocked()
|
||||
}
|
||||
|
||||
func (m *windowsManager) disableLocked() error {
|
||||
if m.session == 0 {
|
||||
return nil
|
||||
}
|
||||
session := m.session
|
||||
m.session = 0
|
||||
if err := closeSession(session); err != nil {
|
||||
return fmt.Errorf("close wfp session: %w", err)
|
||||
}
|
||||
log.Info("dns firewall removed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// failOrLog returns err unchanged in strict mode. In non-strict mode the
|
||||
// error is logged and nil is returned.
|
||||
func (m *windowsManager) failOrLog(strict bool, err error) error {
|
||||
if strict {
|
||||
return err
|
||||
}
|
||||
log.Errorf("dns firewall: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// New returns a Windows DNS firewall manager backed by WFP.
|
||||
func New() Manager {
|
||||
return &windowsManager{}
|
||||
}
|
||||
|
||||
// strictMode reports whether strict mode is enabled via env.
|
||||
func strictMode() bool {
|
||||
v, _ := strconv.ParseBool(os.Getenv(EnvStrict))
|
||||
return v
|
||||
}
|
||||
|
||||
// luidFromGUID converts a Windows interface GUID string to its LUID.
|
||||
func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in luidFromGUID: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
guid, err := windows.GUIDFromString(ifaceGUID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse guid: %w", err)
|
||||
}
|
||||
rc, _, _ := procConvertInterfaceGuidToLuid.Call(
|
||||
uintptr(unsafe.Pointer(&guid)),
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
)
|
||||
if rc != 0 {
|
||||
return 0, fmt.Errorf("ConvertInterfaceGuidToLuid returned %d", rc)
|
||||
}
|
||||
return luid, nil
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStrictMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
set bool
|
||||
want bool
|
||||
}{
|
||||
{name: "unset", want: false},
|
||||
{name: "true", val: "true", set: true, want: true},
|
||||
{name: "1", val: "1", set: true, want: true},
|
||||
{name: "false", val: "false", set: true, want: false},
|
||||
{name: "invalid is false", val: "garbage", set: true, want: false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvStrict, tc.val)
|
||||
if !tc.set {
|
||||
os.Unsetenv(EnvStrict)
|
||||
}
|
||||
if got := strictMode(); got != tc.want {
|
||||
t.Fatalf("strictMode() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerDisableIdempotent(t *testing.T) {
|
||||
m := &windowsManager{}
|
||||
if err := m.Disable(); err != nil {
|
||||
t.Fatalf("first Disable on fresh manager: %v", err)
|
||||
}
|
||||
if err := m.Disable(); err != nil {
|
||||
t.Fatalf("second Disable on fresh manager: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session should remain zero, got %d", m.session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerEnableNoOpWhenDisabledByEnv(t *testing.T) {
|
||||
t.Setenv(EnvDisable, "true")
|
||||
|
||||
m := &windowsManager{}
|
||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
||||
t.Fatalf("Enable should be a no-op when firewall disabled by env: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session must remain zero when env disables firewall, got %d", m.session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerEnableNoOpWhenPortsEmpty(t *testing.T) {
|
||||
t.Setenv(EnvPorts, "")
|
||||
|
||||
m := &windowsManager{}
|
||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
||||
t.Fatalf("Enable should be a no-op when ports list is empty: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session must remain zero when ports list is empty, got %d", m.session)
|
||||
}
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/helpers.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func createWtFwpmDisplayData0(name, description string) (*wtFwpmDisplayData0, error) {
|
||||
namePtr, err := windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
descriptionPtr, err := windows.UTF16PtrFromString(description)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
return &wtFwpmDisplayData0{
|
||||
name: namePtr,
|
||||
description: descriptionPtr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func filterWeight(weight uint8) wtFwpValue0 {
|
||||
return wtFwpValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(weight),
|
||||
}
|
||||
}
|
||||
|
||||
func wrapErr(err error) error {
|
||||
var errno syscall.Errno
|
||||
if !errors.As(err, &errno) {
|
||||
return err
|
||||
}
|
||||
_, file, line, ok := runtime.Caller(1)
|
||||
if !ok {
|
||||
return fmt.Errorf("wfp error at unknown location: %w", err)
|
||||
}
|
||||
return fmt.Errorf("wfp error at %s:%d: %w", file, line, err)
|
||||
}
|
||||
@@ -1,249 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
||||
*
|
||||
* Filter installers adapted from wireguard-windows tunnel/firewall/rules.go.
|
||||
* The block-DNS approach (port 53 + UDP/TCP) matches what wireguard-windows
|
||||
* uses for its kill-switch DNS leak protection. We extend it with a
|
||||
* configurable port set so we also cover :853 (DoT) and any future ports.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// Filters install at outbound ALE_AUTH_CONNECT layers only; inbound replies
|
||||
// follow the authorized outbound flow.
|
||||
|
||||
// permitTunInterface installs a permit filter for any traffic whose local
|
||||
// interface is the netbird tunnel.
|
||||
func permitTunInterface(session uintptr, base *baseObjects, weight uint8, ifLUID uint64) error {
|
||||
cond := wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_LOCAL_INTERFACE,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT64,
|
||||
value: uintptr(unsafe.Pointer(&ifLUID)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: 1,
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, "Permit netbird tunnel")
|
||||
}
|
||||
|
||||
// permitDaemonByAppID installs a permit filter matching the netbird daemon
|
||||
// executable by App-ID. App-ID alone is sufficient because netbird.exe is a
|
||||
// dedicated binary.
|
||||
func permitDaemonByAppID(session uintptr, base *baseObjects, daemonExe string, weight uint8) error {
|
||||
appID, err := daemonAppID(daemonExe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fwpmFreeMemory0(unsafe.Pointer(&appID))
|
||||
|
||||
cond := wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_ALE_APP_ID,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_BYTE_BLOB_TYPE,
|
||||
value: uintptr(unsafe.Pointer(appID)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: 1,
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, "Permit netbird daemon")
|
||||
}
|
||||
|
||||
// permitVirtualDNSIP installs a permit filter for DNS-port traffic destined
|
||||
// for the in-tunnel virtual DNS IP. Used in strict mode in lieu of
|
||||
// permitTunInterface.
|
||||
func permitVirtualDNSIP(session uintptr, base *baseObjects, ip netip.Addr, ports []uint16, weight uint8) error {
|
||||
var merr *multierror.Error
|
||||
for _, port := range ports {
|
||||
if err := permitDNSToHost(session, base, ip, port, weight); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit %s:%d: %w", ip, port, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func permitDNSToHost(session uintptr, base *baseObjects, ip netip.Addr, port uint16, weight uint8) error {
|
||||
if !ip.IsValid() {
|
||||
return fmt.Errorf("invalid address")
|
||||
}
|
||||
|
||||
var addrCond wtFwpmFilterCondition0
|
||||
var layer windows.GUID
|
||||
// v6 backing must outlive fwpmFilterAdd0; keep it on this stack frame.
|
||||
var v6 wtFwpByteArray16
|
||||
|
||||
if ip.Is4() {
|
||||
v4 := ip.As4()
|
||||
addrCond = wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT32,
|
||||
value: uintptr(binary.BigEndian.Uint32(v4[:])),
|
||||
},
|
||||
}
|
||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V4
|
||||
} else {
|
||||
v6 = wtFwpByteArray16{byteArray16: ip.As16()}
|
||||
addrCond = wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_BYTE_ARRAY16_TYPE,
|
||||
value: uintptr(unsafe.Pointer(&v6)),
|
||||
},
|
||||
}
|
||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V6
|
||||
}
|
||||
|
||||
conditions := [2]wtFwpmFilterCondition0{
|
||||
addrCond,
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
}
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: uint32(len(conditions)),
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
display, err := createWtFwpmDisplayData0(fmt.Sprintf("Permit DNS to %s:%d", ip, port), "")
|
||||
if err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
filter.displayData = *display
|
||||
filter.layerKey = layer
|
||||
|
||||
var filterID uint64
|
||||
if err := fwpmFilterAdd0(session, &filter, 0, &filterID); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
_ = v6
|
||||
return nil
|
||||
}
|
||||
|
||||
// blockDNSPorts installs a deny filter for outbound traffic to each of the
|
||||
// given remote ports over UDP or TCP. Per-port and per-layer failures are
|
||||
// accumulated; partial coverage is preferred over zero coverage.
|
||||
func blockDNSPorts(session uintptr, base *baseObjects, ports []uint16, weight uint8) error {
|
||||
var merr *multierror.Error
|
||||
for _, port := range ports {
|
||||
if err := blockDNSPort(session, base, port, weight); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("block port %d: %w", port, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func blockDNSPort(session uintptr, base *baseObjects, port uint16, weight uint8) error {
|
||||
conditions := [3]wtFwpmFilterCondition0{
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(cIPPROTO_UDP),
|
||||
},
|
||||
},
|
||||
// Repeat the IP_PROTOCOL condition for logical OR with TCP.
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(cIPPROTO_TCP),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: uint32(len(conditions)),
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_BLOCK},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, fmt.Sprintf("Block DNS port %d", port))
|
||||
}
|
||||
|
||||
// addOutboundFilters installs the same filter on the v4 and v6 outbound ALE
|
||||
// connect layers. v4 and v6 are installed independently: failure on one
|
||||
// layer does not abort the other, and the accumulated errors are returned.
|
||||
// Partial coverage is preferred over zero coverage.
|
||||
func addOutboundFilters(session uintptr, filter *wtFwpmFilter0, name string) error {
|
||||
layers := [...]struct {
|
||||
layer windows.GUID
|
||||
label string
|
||||
}{
|
||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V4, name + " (IPv4)"},
|
||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V6, name + " (IPv6)"},
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, l := range layers {
|
||||
display, err := createWtFwpmDisplayData0(l.label, "")
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
||||
continue
|
||||
}
|
||||
filter.displayData = *display
|
||||
filter.layerKey = l.layer
|
||||
|
||||
var filterID uint64
|
||||
if err := fwpmFilterAdd0(session, filter, 0, &filterID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
@@ -1,177 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
||||
*
|
||||
* Session lifecycle and the high-level Install/Close entry points adapted
|
||||
* from wireguard-windows tunnel/firewall.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// installConfig is the input to installFilters.
|
||||
type installConfig struct {
|
||||
tunLUID uint64
|
||||
daemonExe string
|
||||
blockedPorts []uint16
|
||||
// strict, when true, narrows the carve-out from "anything on tun" to
|
||||
// "DNS only to virtualDNSIP". virtualDNSIP must be valid in this case.
|
||||
strict bool
|
||||
virtualDNSIP netip.Addr
|
||||
}
|
||||
|
||||
// baseObjects holds the GUIDs of the WFP provider and sublayer registered
|
||||
// for our session. Both are randomly generated per session.
|
||||
type baseObjects struct {
|
||||
provider windows.GUID
|
||||
filters windows.GUID
|
||||
}
|
||||
|
||||
// installFilters opens a dynamic WFP session and installs the netbird DNS
|
||||
// firewall filters. Returns a zero session on hard failure (session create,
|
||||
// base objects); a non-zero session with a non-nil error is a partial install
|
||||
// (some per-filter installs failed) and is safe to close.
|
||||
func installFilters(cfg installConfig) (session uintptr, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Dynamic session: kernel will clean up on process exit even
|
||||
// if we leave the handle dangling here.
|
||||
err = fmt.Errorf("panic in installFilters: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(cfg.blockedPorts) == 0 {
|
||||
return 0, errors.New("dns firewall: no blocked ports configured")
|
||||
}
|
||||
if cfg.strict && !cfg.virtualDNSIP.IsValid() {
|
||||
return 0, errors.New("dns firewall: strict mode requires a valid virtual DNS IP")
|
||||
}
|
||||
|
||||
session, err = createSession()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
base, err := registerBaseObjects(session)
|
||||
if err != nil {
|
||||
_ = fwpmEngineClose0(session)
|
||||
return 0, fmt.Errorf("register base objects: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
if cfg.strict {
|
||||
if err := permitVirtualDNSIP(session, base, cfg.virtualDNSIP, cfg.blockedPorts, 15); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit virtual dns: %w", err))
|
||||
}
|
||||
} else {
|
||||
if err := permitTunInterface(session, base, 15, cfg.tunLUID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit tun interface: %w", err))
|
||||
}
|
||||
}
|
||||
if err := permitDaemonByAppID(session, base, cfg.daemonExe, 14); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit netbird daemon: %w", err))
|
||||
}
|
||||
if err := blockDNSPorts(session, base, cfg.blockedPorts, 10); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("block dns ports: %w", err))
|
||||
}
|
||||
|
||||
return session, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// closeSession tears down a WFP session previously opened by installFilters.
|
||||
// All filters owned by the session are removed.
|
||||
func closeSession(session uintptr) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in closeSession: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if session == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := fwpmEngineClose0(session); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createSession() (uintptr, error) {
|
||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall dynamic session")
|
||||
if err != nil {
|
||||
return 0, wrapErr(err)
|
||||
}
|
||||
session := wtFwpmSession0{
|
||||
displayData: *displayData,
|
||||
flags: cFWPM_SESSION_FLAG_DYNAMIC,
|
||||
txnWaitTimeoutInMSec: windows.INFINITE,
|
||||
}
|
||||
var handle uintptr
|
||||
if err := fwpmEngineOpen0(nil, cRPC_C_AUTHN_WINNT, nil, &session, unsafe.Pointer(&handle)); err != nil {
|
||||
return 0, wrapErr(err)
|
||||
}
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func registerBaseObjects(session uintptr) (*baseObjects, error) {
|
||||
bo := &baseObjects{}
|
||||
var err error
|
||||
if bo.provider, err = windows.GenerateGUID(); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
if bo.filters, err = windows.GenerateGUID(); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall provider")
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
provider := wtFwpmProvider0{
|
||||
providerKey: bo.provider,
|
||||
displayData: *displayData,
|
||||
}
|
||||
if err := fwpmProviderAdd0(session, &provider, 0); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
subDisplay, err := createWtFwpmDisplayData0("NetBird DNS firewall filters", "Permit and block filters")
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
sublayer := wtFwpmSublayer0{
|
||||
subLayerKey: bo.filters,
|
||||
displayData: *subDisplay,
|
||||
providerKey: &bo.provider,
|
||||
weight: ^uint16(0),
|
||||
}
|
||||
if err := fwpmSubLayerAdd0(session, &sublayer, 0); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
return bo, nil
|
||||
}
|
||||
|
||||
// daemonAppID returns the WFP App-ID byte blob for the given executable path.
|
||||
func daemonAppID(path string) (*wtFwpByteBlob, error) {
|
||||
pathPtr, err := windows.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
var appID *wtFwpByteBlob
|
||||
if err := fwpmGetAppIdFromFileName0(pathPtr, unsafe.Pointer(&appID)); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
return appID, nil
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/syscall_windows.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineopen0
|
||||
//sys fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmEngineOpen0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineclose0
|
||||
//sys fwpmEngineClose0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmEngineClose0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmsublayeradd0
|
||||
//sys fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmSubLayerAdd0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmgetappidfromfilename0
|
||||
//sys fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmGetAppIdFromFileName0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfreememory0
|
||||
//sys fwpmFreeMemory0(p unsafe.Pointer) = fwpuclnt.FwpmFreeMemory0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfilteradd0
|
||||
//sys fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) [failretval!=0] = fwpuclnt.FwpmFilterAdd0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/Fwpmu/nf-fwpmu-fwpmtransactionbegin0
|
||||
//sys fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionBegin0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactioncommit0
|
||||
//sys fwpmTransactionCommit0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionCommit0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactionabort0
|
||||
//sys fwpmTransactionAbort0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionAbort0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmprovideradd0
|
||||
//sys fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmProviderAdd0
|
||||
@@ -1,414 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
anysizeArray = 1 // ANYSIZE_ARRAY defined in winnt.h
|
||||
|
||||
wtFwpBitmapArray64_Size = 8
|
||||
|
||||
wtFwpByteArray16_Size = 16
|
||||
|
||||
wtFwpByteArray6_Size = 6
|
||||
|
||||
wtFwpmAction0_Size = 20
|
||||
wtFwpmAction0_filterType_Offset = 4
|
||||
|
||||
wtFwpV4AddrAndMask_Size = 8
|
||||
wtFwpV4AddrAndMask_mask_Offset = 4
|
||||
|
||||
wtFwpV6AddrAndMask_Size = 17
|
||||
wtFwpV6AddrAndMask_prefixLength_Offset = 16
|
||||
)
|
||||
|
||||
type wtFwpActionFlag uint32
|
||||
|
||||
const (
|
||||
cFWP_ACTION_FLAG_TERMINATING wtFwpActionFlag = 0x00001000
|
||||
cFWP_ACTION_FLAG_NON_TERMINATING wtFwpActionFlag = 0x00002000
|
||||
cFWP_ACTION_FLAG_CALLOUT wtFwpActionFlag = 0x00004000
|
||||
)
|
||||
|
||||
// FWP_ACTION_TYPE defined in fwptypes.h
|
||||
type wtFwpActionType uint32
|
||||
|
||||
const (
|
||||
cFWP_ACTION_BLOCK wtFwpActionType = wtFwpActionType(0x00000001 | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_PERMIT wtFwpActionType = wtFwpActionType(0x00000002 | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_TERMINATING wtFwpActionType = wtFwpActionType(0x00000003 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_INSPECTION wtFwpActionType = wtFwpActionType(0x00000004 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_NON_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_UNKNOWN wtFwpActionType = wtFwpActionType(0x00000005 | cFWP_ACTION_FLAG_CALLOUT)
|
||||
cFWP_ACTION_CONTINUE wtFwpActionType = wtFwpActionType(0x00000006 | cFWP_ACTION_FLAG_NON_TERMINATING)
|
||||
cFWP_ACTION_NONE wtFwpActionType = 0x00000007
|
||||
cFWP_ACTION_NONE_NO_MATCH wtFwpActionType = 0x00000008
|
||||
cFWP_ACTION_BITMAP_INDEX_SET wtFwpActionType = 0x00000009
|
||||
)
|
||||
|
||||
// FWP_BYTE_BLOB defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_blob_)
|
||||
type wtFwpByteBlob struct {
|
||||
size uint32
|
||||
data *uint8
|
||||
}
|
||||
|
||||
// FWP_MATCH_TYPE defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_match_type_)
|
||||
type wtFwpMatchType uint32
|
||||
|
||||
const (
|
||||
cFWP_MATCH_EQUAL wtFwpMatchType = 0
|
||||
cFWP_MATCH_GREATER wtFwpMatchType = cFWP_MATCH_EQUAL + 1
|
||||
cFWP_MATCH_LESS wtFwpMatchType = cFWP_MATCH_GREATER + 1
|
||||
cFWP_MATCH_GREATER_OR_EQUAL wtFwpMatchType = cFWP_MATCH_LESS + 1
|
||||
cFWP_MATCH_LESS_OR_EQUAL wtFwpMatchType = cFWP_MATCH_GREATER_OR_EQUAL + 1
|
||||
cFWP_MATCH_RANGE wtFwpMatchType = cFWP_MATCH_LESS_OR_EQUAL + 1
|
||||
cFWP_MATCH_FLAGS_ALL_SET wtFwpMatchType = cFWP_MATCH_RANGE + 1
|
||||
cFWP_MATCH_FLAGS_ANY_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ALL_SET + 1
|
||||
cFWP_MATCH_FLAGS_NONE_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ANY_SET + 1
|
||||
cFWP_MATCH_EQUAL_CASE_INSENSITIVE wtFwpMatchType = cFWP_MATCH_FLAGS_NONE_SET + 1
|
||||
cFWP_MATCH_NOT_EQUAL wtFwpMatchType = cFWP_MATCH_EQUAL_CASE_INSENSITIVE + 1
|
||||
cFWP_MATCH_PREFIX wtFwpMatchType = cFWP_MATCH_NOT_EQUAL + 1
|
||||
cFWP_MATCH_NOT_PREFIX wtFwpMatchType = cFWP_MATCH_PREFIX + 1
|
||||
cFWP_MATCH_TYPE_MAX wtFwpMatchType = cFWP_MATCH_NOT_PREFIX + 1
|
||||
)
|
||||
|
||||
// FWPM_ACTION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_action0_)
|
||||
type wtFwpmAction0 struct {
|
||||
_type wtFwpActionType
|
||||
filterType windows.GUID // Windows type: GUID
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 4cd62a49-59c3-4969-b7f3-bda5d32890a4
|
||||
var cFWPM_CONDITION_IP_LOCAL_INTERFACE = windows.GUID{
|
||||
Data1: 0x4cd62a49,
|
||||
Data2: 0x59c3,
|
||||
Data3: 0x4969,
|
||||
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. b235ae9a-1d64-49b8-a44c-5ff3d9095045
|
||||
var cFWPM_CONDITION_IP_REMOTE_ADDRESS = windows.GUID{
|
||||
Data1: 0xb235ae9a,
|
||||
Data2: 0x1d64,
|
||||
Data3: 0x49b8,
|
||||
Data4: [8]byte{0xa4, 0x4c, 0x5f, 0xf3, 0xd9, 0x09, 0x50, 0x45},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
|
||||
var cFWPM_CONDITION_IP_PROTOCOL = windows.GUID{
|
||||
Data1: 0x3971ef2b,
|
||||
Data2: 0x623e,
|
||||
Data3: 0x4f9a,
|
||||
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 0c1ba1af-5765-453f-af22-a8f791ac775b
|
||||
var cFWPM_CONDITION_IP_LOCAL_PORT = windows.GUID{
|
||||
Data1: 0x0c1ba1af,
|
||||
Data2: 0x5765,
|
||||
Data3: 0x453f,
|
||||
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. c35a604d-d22b-4e1a-91b4-68f674ee674b
|
||||
var cFWPM_CONDITION_IP_REMOTE_PORT = windows.GUID{
|
||||
Data1: 0xc35a604d,
|
||||
Data2: 0xd22b,
|
||||
Data3: 0x4e1a,
|
||||
Data4: [8]byte{0x91, 0xb4, 0x68, 0xf6, 0x74, 0xee, 0x67, 0x4b},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. d78e1e87-8644-4ea5-9437-d809ecefc971
|
||||
var cFWPM_CONDITION_ALE_APP_ID = windows.GUID{
|
||||
Data1: 0xd78e1e87,
|
||||
Data2: 0x8644,
|
||||
Data3: 0x4ea5,
|
||||
Data4: [8]byte{0x94, 0x37, 0xd8, 0x09, 0xec, 0xef, 0xc9, 0x71},
|
||||
}
|
||||
|
||||
// af043a0a-b34d-4f86-979c-c90371af6e66
|
||||
var cFWPM_CONDITION_ALE_USER_ID = windows.GUID{
|
||||
Data1: 0xaf043a0a,
|
||||
Data2: 0xb34d,
|
||||
Data3: 0x4f86,
|
||||
Data4: [8]byte{0x97, 0x9c, 0xc9, 0x03, 0x71, 0xaf, 0x6e, 0x66},
|
||||
}
|
||||
|
||||
// d9ee00de-c1ef-4617-bfe3-ffd8f5a08957
|
||||
var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
|
||||
Data1: 0xd9ee00de,
|
||||
Data2: 0xc1ef,
|
||||
Data3: 0x4617,
|
||||
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
|
||||
}
|
||||
|
||||
var (
|
||||
cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
|
||||
cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
|
||||
)
|
||||
|
||||
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
|
||||
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
|
||||
Data1: 0x7bc43cbf,
|
||||
Data2: 0x37ba,
|
||||
Data3: 0x45f1,
|
||||
Data4: [8]byte{0xb7, 0x4a, 0x82, 0xff, 0x51, 0x8e, 0xeb, 0x10},
|
||||
}
|
||||
|
||||
type wtFwpmL2Flags uint32
|
||||
|
||||
const cFWP_CONDITION_L2_IS_VM2VM wtFwpmL2Flags = 0x00000010
|
||||
|
||||
var cFWPM_CONDITION_FLAGS = windows.GUID{
|
||||
Data1: 0x632ce23b,
|
||||
Data2: 0x5167,
|
||||
Data3: 0x435c,
|
||||
Data4: [8]byte{0x86, 0xd7, 0xe9, 0x03, 0x68, 0x4a, 0xa8, 0x0c},
|
||||
}
|
||||
|
||||
type wtFwpmFlags uint32
|
||||
|
||||
const cFWP_CONDITION_FLAG_IS_LOOPBACK wtFwpmFlags = 0x00000001
|
||||
|
||||
// Defined in fwpmtypes.h
|
||||
type wtFwpmFilterFlags uint32
|
||||
|
||||
const (
|
||||
cFWPM_FILTER_FLAG_NONE wtFwpmFilterFlags = 0x00000000
|
||||
cFWPM_FILTER_FLAG_PERSISTENT wtFwpmFilterFlags = 0x00000001
|
||||
cFWPM_FILTER_FLAG_BOOTTIME wtFwpmFilterFlags = 0x00000002
|
||||
cFWPM_FILTER_FLAG_HAS_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000004
|
||||
cFWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT wtFwpmFilterFlags = 0x00000008
|
||||
cFWPM_FILTER_FLAG_PERMIT_IF_CALLOUT_UNREGISTERED wtFwpmFilterFlags = 0x00000010
|
||||
cFWPM_FILTER_FLAG_DISABLED wtFwpmFilterFlags = 0x00000020
|
||||
cFWPM_FILTER_FLAG_INDEXED wtFwpmFilterFlags = 0x00000040
|
||||
cFWPM_FILTER_FLAG_HAS_SECURITY_REALM_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000080
|
||||
cFWPM_FILTER_FLAG_SYSTEMOS_ONLY wtFwpmFilterFlags = 0x00000100
|
||||
cFWPM_FILTER_FLAG_GAMEOS_ONLY wtFwpmFilterFlags = 0x00000200
|
||||
cFWPM_FILTER_FLAG_SILENT_MODE wtFwpmFilterFlags = 0x00000400
|
||||
cFWPM_FILTER_FLAG_IPSEC_NO_ACQUIRE_INITIATE wtFwpmFilterFlags = 0x00000800
|
||||
)
|
||||
|
||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V4 (c38d57d1-05a7-4c33-904f-7fbceee60e82) defined in fwpmu.h
|
||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V4 = windows.GUID{
|
||||
Data1: 0xc38d57d1,
|
||||
Data2: 0x05a7,
|
||||
Data3: 0x4c33,
|
||||
Data4: [8]byte{0x90, 0x4f, 0x7f, 0xbc, 0xee, 0xe6, 0x0e, 0x82},
|
||||
}
|
||||
|
||||
// e1cd9fe7-f4b5-4273-96c0-592e487b8650
|
||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = windows.GUID{
|
||||
Data1: 0xe1cd9fe7,
|
||||
Data2: 0xf4b5,
|
||||
Data3: 0x4273,
|
||||
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
|
||||
}
|
||||
|
||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V6 (4a72393b-319f-44bc-84c3-ba54dcb3b6b4) defined in fwpmu.h
|
||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V6 = windows.GUID{
|
||||
Data1: 0x4a72393b,
|
||||
Data2: 0x319f,
|
||||
Data3: 0x44bc,
|
||||
Data4: [8]byte{0x84, 0xc3, 0xba, 0x54, 0xdc, 0xb3, 0xb6, 0xb4},
|
||||
}
|
||||
|
||||
// a3b42c97-9f04-4672-b87e-cee9c483257f
|
||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = windows.GUID{
|
||||
Data1: 0xa3b42c97,
|
||||
Data2: 0x9f04,
|
||||
Data3: 0x4672,
|
||||
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
|
||||
}
|
||||
|
||||
// 94c44912-9d6f-4ebf-b995-05ab8a088d1b
|
||||
var cFWPM_LAYER_OUTBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
||||
Data1: 0x94c44912,
|
||||
Data2: 0x9d6f,
|
||||
Data3: 0x4ebf,
|
||||
Data4: [8]byte{0xb9, 0x95, 0x05, 0xab, 0x8a, 0x08, 0x8d, 0x1b},
|
||||
}
|
||||
|
||||
// d4220bd3-62ce-4f08-ae88-b56e8526df50
|
||||
var cFWPM_LAYER_INBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
||||
Data1: 0xd4220bd3,
|
||||
Data2: 0x62ce,
|
||||
Data3: 0x4f08,
|
||||
Data4: [8]byte{0xae, 0x88, 0xb5, 0x6e, 0x85, 0x26, 0xdf, 0x50},
|
||||
}
|
||||
|
||||
// FWP_BITMAP_ARRAY64 defined in fwtypes.h
|
||||
type wtFwpBitmapArray64 struct {
|
||||
bitmapArray64 [8]uint8 // Windows type: [8]UINT8
|
||||
}
|
||||
|
||||
// FWP_BYTE_ARRAY6 defined in fwtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array6_)
|
||||
type wtFwpByteArray6 struct {
|
||||
byteArray6 [6]uint8 // Windows type: [6]UINT8
|
||||
}
|
||||
|
||||
// FWP_BYTE_ARRAY16 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array16_)
|
||||
type wtFwpByteArray16 struct {
|
||||
byteArray16 [16]uint8 // Windows type [16]UINT8
|
||||
}
|
||||
|
||||
// FWP_CONDITION_VALUE0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_condition_value0).
|
||||
type wtFwpConditionValue0 wtFwpValue0
|
||||
|
||||
// FWP_DATA_TYPE defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_data_type_)
|
||||
type wtFwpDataType uint
|
||||
|
||||
const (
|
||||
cFWP_EMPTY wtFwpDataType = 0
|
||||
cFWP_UINT8 wtFwpDataType = cFWP_EMPTY + 1
|
||||
cFWP_UINT16 wtFwpDataType = cFWP_UINT8 + 1
|
||||
cFWP_UINT32 wtFwpDataType = cFWP_UINT16 + 1
|
||||
cFWP_UINT64 wtFwpDataType = cFWP_UINT32 + 1
|
||||
cFWP_INT8 wtFwpDataType = cFWP_UINT64 + 1
|
||||
cFWP_INT16 wtFwpDataType = cFWP_INT8 + 1
|
||||
cFWP_INT32 wtFwpDataType = cFWP_INT16 + 1
|
||||
cFWP_INT64 wtFwpDataType = cFWP_INT32 + 1
|
||||
cFWP_FLOAT wtFwpDataType = cFWP_INT64 + 1
|
||||
cFWP_DOUBLE wtFwpDataType = cFWP_FLOAT + 1
|
||||
cFWP_BYTE_ARRAY16_TYPE wtFwpDataType = cFWP_DOUBLE + 1
|
||||
cFWP_BYTE_BLOB_TYPE wtFwpDataType = cFWP_BYTE_ARRAY16_TYPE + 1
|
||||
cFWP_SID wtFwpDataType = cFWP_BYTE_BLOB_TYPE + 1
|
||||
cFWP_SECURITY_DESCRIPTOR_TYPE wtFwpDataType = cFWP_SID + 1
|
||||
cFWP_TOKEN_INFORMATION_TYPE wtFwpDataType = cFWP_SECURITY_DESCRIPTOR_TYPE + 1
|
||||
cFWP_TOKEN_ACCESS_INFORMATION_TYPE wtFwpDataType = cFWP_TOKEN_INFORMATION_TYPE + 1
|
||||
cFWP_UNICODE_STRING_TYPE wtFwpDataType = cFWP_TOKEN_ACCESS_INFORMATION_TYPE + 1
|
||||
cFWP_BYTE_ARRAY6_TYPE wtFwpDataType = cFWP_UNICODE_STRING_TYPE + 1
|
||||
cFWP_BITMAP_INDEX_TYPE wtFwpDataType = cFWP_BYTE_ARRAY6_TYPE + 1
|
||||
cFWP_BITMAP_ARRAY64_TYPE wtFwpDataType = cFWP_BITMAP_INDEX_TYPE + 1
|
||||
cFWP_SINGLE_DATA_TYPE_MAX wtFwpDataType = 0xff
|
||||
cFWP_V4_ADDR_MASK wtFwpDataType = cFWP_SINGLE_DATA_TYPE_MAX + 1
|
||||
cFWP_V6_ADDR_MASK wtFwpDataType = cFWP_V4_ADDR_MASK + 1
|
||||
cFWP_RANGE_TYPE wtFwpDataType = cFWP_V6_ADDR_MASK + 1
|
||||
cFWP_DATA_TYPE_MAX wtFwpDataType = cFWP_RANGE_TYPE + 1
|
||||
)
|
||||
|
||||
// FWP_V4_ADDR_AND_MASK defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v4_addr_and_mask).
|
||||
type wtFwpV4AddrAndMask struct {
|
||||
addr uint32
|
||||
mask uint32
|
||||
}
|
||||
|
||||
// FWP_V6_ADDR_AND_MASK defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v6_addr_and_mask).
|
||||
type wtFwpV6AddrAndMask struct {
|
||||
addr [16]uint8
|
||||
prefixLength uint8
|
||||
}
|
||||
|
||||
// FWP_VALUE0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_value0_)
|
||||
type wtFwpValue0 struct {
|
||||
_type wtFwpDataType
|
||||
value uintptr
|
||||
}
|
||||
|
||||
// FWPM_DISPLAY_DATA0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwpm_display_data0).
|
||||
type wtFwpmDisplayData0 struct {
|
||||
name *uint16 // Windows type: *wchar_t
|
||||
description *uint16 // Windows type: *wchar_t
|
||||
}
|
||||
|
||||
// FWPM_FILTER_CONDITION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter_condition0).
|
||||
type wtFwpmFilterCondition0 struct {
|
||||
fieldKey windows.GUID // Windows type: GUID
|
||||
matchType wtFwpMatchType
|
||||
conditionValue wtFwpConditionValue0
|
||||
}
|
||||
|
||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0_)
|
||||
type wtFwpProvider0 struct {
|
||||
providerKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags uint32
|
||||
providerData wtFwpByteBlob
|
||||
serviceName *uint16 // Windows type: *wchar_t
|
||||
}
|
||||
|
||||
type wtFwpmSessionFlagsValue uint32
|
||||
|
||||
const (
|
||||
cFWPM_SESSION_FLAG_DYNAMIC wtFwpmSessionFlagsValue = 0x00000001 // FWPM_SESSION_FLAG_DYNAMIC defined in fwpmtypes.h
|
||||
)
|
||||
|
||||
// FWPM_SESSION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_session0).
|
||||
type wtFwpmSession0 struct {
|
||||
sessionKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmSessionFlagsValue // Windows type UINT32
|
||||
txnWaitTimeoutInMSec uint32
|
||||
processId uint32 // Windows type: DWORD
|
||||
sid *windows.SID
|
||||
username *uint16 // Windows type: *wchar_t
|
||||
kernelMode uint8 // Windows type: BOOL
|
||||
}
|
||||
|
||||
type wtFwpmSublayerFlags uint32
|
||||
|
||||
const (
|
||||
cFWPM_SUBLAYER_FLAG_PERSISTENT wtFwpmSublayerFlags = 0x00000001 // FWPM_SUBLAYER_FLAG_PERSISTENT defined in fwpmtypes.h
|
||||
)
|
||||
|
||||
// FWPM_SUBLAYER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_sublayer0_)
|
||||
type wtFwpmSublayer0 struct {
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmSublayerFlags
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
weight uint16
|
||||
}
|
||||
|
||||
// Defined in rpcdce.h
|
||||
type wtRpcCAuthN uint32
|
||||
|
||||
const (
|
||||
cRPC_C_AUTHN_NONE wtRpcCAuthN = 0
|
||||
cRPC_C_AUTHN_WINNT wtRpcCAuthN = 10
|
||||
cRPC_C_AUTHN_DEFAULT wtRpcCAuthN = 0xFFFFFFFF
|
||||
)
|
||||
|
||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/sv-se/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0).
|
||||
type wtFwpmProvider0 struct {
|
||||
providerKey windows.GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags uint32
|
||||
providerData wtFwpByteBlob
|
||||
serviceName *uint16
|
||||
}
|
||||
|
||||
type wtIPProto uint32
|
||||
|
||||
const (
|
||||
cIPPROTO_ICMP wtIPProto = 1
|
||||
cIPPROTO_ICMPV6 wtIPProto = 58
|
||||
cIPPROTO_TCP wtIPProto = 6
|
||||
cIPPROTO_UDP wtIPProto = 17
|
||||
)
|
||||
|
||||
const (
|
||||
cFWP_ACTRL_MATCH_FILTER = 1
|
||||
)
|
||||
@@ -1,92 +0,0 @@
|
||||
//go:build windows && (386 || arm)
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_32.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
wtFwpByteBlob_Size = 8
|
||||
wtFwpByteBlob_data_Offset = 4
|
||||
|
||||
wtFwpConditionValue0_Size = 8
|
||||
wtFwpConditionValue0_uint8_Offset = 4
|
||||
|
||||
wtFwpmDisplayData0_Size = 8
|
||||
wtFwpmDisplayData0_description_Offset = 4
|
||||
|
||||
wtFwpmFilter0_Size = 152
|
||||
wtFwpmFilter0_displayData_Offset = 16
|
||||
wtFwpmFilter0_flags_Offset = 24
|
||||
wtFwpmFilter0_providerKey_Offset = 28
|
||||
wtFwpmFilter0_providerData_Offset = 32
|
||||
wtFwpmFilter0_layerKey_Offset = 40
|
||||
wtFwpmFilter0_subLayerKey_Offset = 56
|
||||
wtFwpmFilter0_weight_Offset = 72
|
||||
wtFwpmFilter0_numFilterConditions_Offset = 80
|
||||
wtFwpmFilter0_filterCondition_Offset = 84
|
||||
wtFwpmFilter0_action_Offset = 88
|
||||
wtFwpmFilter0_providerContextKey_Offset = 112
|
||||
wtFwpmFilter0_reserved_Offset = 128
|
||||
wtFwpmFilter0_filterID_Offset = 136
|
||||
wtFwpmFilter0_effectiveWeight_Offset = 144
|
||||
|
||||
wtFwpmFilterCondition0_Size = 28
|
||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
||||
wtFwpmFilterCondition0_conditionValue_Offset = 20
|
||||
|
||||
wtFwpmSession0_Size = 48
|
||||
wtFwpmSession0_displayData_Offset = 16
|
||||
wtFwpmSession0_flags_Offset = 24
|
||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 28
|
||||
wtFwpmSession0_processId_Offset = 32
|
||||
wtFwpmSession0_sid_Offset = 36
|
||||
wtFwpmSession0_username_Offset = 40
|
||||
wtFwpmSession0_kernelMode_Offset = 44
|
||||
|
||||
wtFwpmSublayer0_Size = 44
|
||||
wtFwpmSublayer0_displayData_Offset = 16
|
||||
wtFwpmSublayer0_flags_Offset = 24
|
||||
wtFwpmSublayer0_providerKey_Offset = 28
|
||||
wtFwpmSublayer0_providerData_Offset = 32
|
||||
wtFwpmSublayer0_weight_Offset = 40
|
||||
|
||||
wtFwpProvider0_Size = 40
|
||||
wtFwpProvider0_displayData_Offset = 16
|
||||
wtFwpProvider0_flags_Offset = 24
|
||||
wtFwpProvider0_providerData_Offset = 28
|
||||
wtFwpProvider0_serviceName_Offset = 36
|
||||
|
||||
wtFwpTokenInformation_Size = 16
|
||||
|
||||
wtFwpValue0_Size = 8
|
||||
wtFwpValue0_value_Offset = 4
|
||||
)
|
||||
|
||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
||||
type wtFwpmFilter0 struct {
|
||||
filterKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmFilterFlags
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
layerKey windows.GUID // Windows type: GUID
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
weight wtFwpValue0
|
||||
numFilterConditions uint32
|
||||
filterCondition *wtFwpmFilterCondition0
|
||||
action wtFwpmAction0
|
||||
offset1 [4]byte // Layout correction field
|
||||
providerContextKey windows.GUID // Windows type: GUID
|
||||
reserved *windows.GUID // Windows type: *GUID
|
||||
offset2 [4]byte // Layout correction field
|
||||
filterID uint64
|
||||
effectiveWeight wtFwpValue0
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
//go:build windows && (amd64 || arm64)
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_64.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
wtFwpByteBlob_Size = 16
|
||||
wtFwpByteBlob_data_Offset = 8
|
||||
|
||||
wtFwpConditionValue0_Size = 16
|
||||
wtFwpConditionValue0_uint8_Offset = 8
|
||||
|
||||
wtFwpmDisplayData0_Size = 16
|
||||
wtFwpmDisplayData0_description_Offset = 8
|
||||
|
||||
wtFwpmFilter0_Size = 200
|
||||
wtFwpmFilter0_displayData_Offset = 16
|
||||
wtFwpmFilter0_flags_Offset = 32
|
||||
wtFwpmFilter0_providerKey_Offset = 40
|
||||
wtFwpmFilter0_providerData_Offset = 48
|
||||
wtFwpmFilter0_layerKey_Offset = 64
|
||||
wtFwpmFilter0_subLayerKey_Offset = 80
|
||||
wtFwpmFilter0_weight_Offset = 96
|
||||
wtFwpmFilter0_numFilterConditions_Offset = 112
|
||||
wtFwpmFilter0_filterCondition_Offset = 120
|
||||
wtFwpmFilter0_action_Offset = 128
|
||||
wtFwpmFilter0_providerContextKey_Offset = 152
|
||||
wtFwpmFilter0_reserved_Offset = 168
|
||||
wtFwpmFilter0_filterID_Offset = 176
|
||||
wtFwpmFilter0_effectiveWeight_Offset = 184
|
||||
|
||||
wtFwpmFilterCondition0_Size = 40
|
||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
||||
wtFwpmFilterCondition0_conditionValue_Offset = 24
|
||||
|
||||
wtFwpmSession0_Size = 72
|
||||
wtFwpmSession0_displayData_Offset = 16
|
||||
wtFwpmSession0_flags_Offset = 32
|
||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 36
|
||||
wtFwpmSession0_processId_Offset = 40
|
||||
wtFwpmSession0_sid_Offset = 48
|
||||
wtFwpmSession0_username_Offset = 56
|
||||
wtFwpmSession0_kernelMode_Offset = 64
|
||||
|
||||
wtFwpmSublayer0_Size = 72
|
||||
wtFwpmSublayer0_displayData_Offset = 16
|
||||
wtFwpmSublayer0_flags_Offset = 32
|
||||
wtFwpmSublayer0_providerKey_Offset = 40
|
||||
wtFwpmSublayer0_providerData_Offset = 48
|
||||
wtFwpmSublayer0_weight_Offset = 64
|
||||
|
||||
wtFwpProvider0_Size = 64
|
||||
wtFwpProvider0_displayData_Offset = 16
|
||||
wtFwpProvider0_flags_Offset = 32
|
||||
wtFwpProvider0_providerData_Offset = 40
|
||||
wtFwpProvider0_serviceName_Offset = 56
|
||||
|
||||
wtFwpValue0_Size = 16
|
||||
wtFwpValue0_value_Offset = 8
|
||||
)
|
||||
|
||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
||||
type wtFwpmFilter0 struct {
|
||||
filterKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmFilterFlags // Windows type: UINT32
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
layerKey windows.GUID // Windows type: GUID
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
weight wtFwpValue0
|
||||
numFilterConditions uint32
|
||||
filterCondition *wtFwpmFilterCondition0
|
||||
action wtFwpmAction0
|
||||
offset1 [4]byte // Layout correction field
|
||||
providerContextKey windows.GUID // Windows type: GUID
|
||||
reserved *windows.GUID // Windows type: *GUID
|
||||
filterID uint64
|
||||
effectiveWeight wtFwpValue0
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modfwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
|
||||
|
||||
procFwpmEngineClose0 = modfwpuclnt.NewProc("FwpmEngineClose0")
|
||||
procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0")
|
||||
procFwpmFilterAdd0 = modfwpuclnt.NewProc("FwpmFilterAdd0")
|
||||
procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0")
|
||||
procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0")
|
||||
procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0")
|
||||
procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0")
|
||||
procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0")
|
||||
procFwpmTransactionBegin0 = modfwpuclnt.NewProc("FwpmTransactionBegin0")
|
||||
procFwpmTransactionCommit0 = modfwpuclnt.NewProc("FwpmTransactionCommit0")
|
||||
)
|
||||
|
||||
func fwpmEngineClose0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmEngineClose0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmFreeMemory0(p unsafe.Pointer) {
|
||||
syscall.Syscall(procFwpmFreeMemory0.Addr(), 1, uintptr(p), 0, 0)
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd))
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd))
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionAbort0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionAbort0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||
)
|
||||
@@ -75,7 +74,6 @@ type registryConfigurator struct {
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
dnsFirewall dnsfw.Manager
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
@@ -96,9 +94,8 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
}
|
||||
|
||||
configurator := ®istryConfigurator{
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
}
|
||||
|
||||
origNameservers, err := configurator.captureOriginalNameservers()
|
||||
@@ -279,8 +276,16 @@ func (r *registryConfigurator) disableWINSForInterface() error {
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
if err := r.applyRouteAll(config); err != nil {
|
||||
return err
|
||||
if config.RouteAll {
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
} else if r.routingAll {
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
@@ -322,35 +327,6 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyRouteAll(config HostDNSConfig) error {
|
||||
if config.RouteAll {
|
||||
if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil {
|
||||
return fmt.Errorf("dns firewall: %w", err)
|
||||
}
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
merr := multierror.Append(nil, fmt.Errorf("add dns setup: %w", err))
|
||||
if dErr := r.dnsFirewall.Disable(); dErr != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rollback dns firewall: %w", dErr))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.dnsFirewall.Disable(); err != nil {
|
||||
log.Errorf("disable dns firewall: %v", err)
|
||||
}
|
||||
if !r.routingAll {
|
||||
return nil
|
||||
}
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
@@ -537,10 +513,6 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := r.dnsFirewall.Disable(); err != nil {
|
||||
log.Errorf("disable dns firewall: %v", err)
|
||||
}
|
||||
|
||||
go r.flushDNSCache()
|
||||
|
||||
return nil
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
||||
)
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
@@ -36,9 +34,8 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||
@@ -137,9 +134,8 @@ func TestNRPTDomainBatching(t *testing.T) {
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
|
||||
@@ -64,6 +64,13 @@
|
||||
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
|
||||
</RegistryKey>
|
||||
</Component>
|
||||
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
|
||||
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
|
||||
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
|
||||
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
|
||||
<RemoveRegistryValue Root="HKCU"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
|
||||
</Component>
|
||||
</StandardDirectory>
|
||||
|
||||
<StandardDirectory Id="CommonAppDataFolder">
|
||||
@@ -76,10 +83,28 @@
|
||||
</Directory>
|
||||
</StandardDirectory>
|
||||
|
||||
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
|
||||
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
|
||||
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
|
||||
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
|
||||
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
|
||||
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
|
||||
<RemoveRegistryValue Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
|
||||
</Component>
|
||||
|
||||
<ComponentGroup Id="NetbirdFilesComponent">
|
||||
<ComponentRef Id="NetbirdFiles" />
|
||||
<ComponentRef Id="NetbirdAumidRegistry" />
|
||||
<ComponentRef Id="NetbirdAutoStart" />
|
||||
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
|
||||
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
|
||||
</ComponentGroup>
|
||||
|
||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||
|
||||
@@ -384,6 +384,350 @@ func (o *OutputOverview) YAML() (string, error) {
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// DownOutput is the structured result of a `netbird down` command.
|
||||
type DownOutput struct {
|
||||
Status string `json:"status" yaml:"status"`
|
||||
}
|
||||
|
||||
// UpOutput is the final structured result of a `netbird up` command.
|
||||
type UpOutput struct {
|
||||
Status string `json:"status" yaml:"status"` // "connected" | "already_connected"
|
||||
ProfileName string `json:"profileName,omitempty" yaml:"profileName,omitempty"`
|
||||
}
|
||||
|
||||
// JSON returns the UpOutput as a JSON string.
|
||||
func (o *UpOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the UpOutput as a YAML string.
|
||||
func (o *UpOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// SSOEvent is emitted before the final UpOutput when interactive SSO is required.
|
||||
// It carries the verification URL and user code so callers can surface them
|
||||
// (e.g. to a Slack channel) without parsing free-form text.
|
||||
type SSOEvent struct {
|
||||
Event string `json:"event" yaml:"event"` // "sso_required"
|
||||
VerificationURIComplete string `json:"verificationUriComplete" yaml:"verificationUriComplete"`
|
||||
UserCode string `json:"userCode,omitempty" yaml:"userCode,omitempty"`
|
||||
}
|
||||
|
||||
// JSON returns the SSOEvent as a JSON string.
|
||||
func (e *SSOEvent) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the SSOEvent as a YAML string.
|
||||
func (e *SSOEvent) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(e)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// NetworkOutput is one row of a networks-list response.
|
||||
type NetworkOutput struct {
|
||||
ID string `json:"id" yaml:"id"`
|
||||
Range string `json:"range,omitempty" yaml:"range,omitempty"`
|
||||
Domains []string `json:"domains,omitempty" yaml:"domains,omitempty"`
|
||||
ResolvedIPs map[string][]string `json:"resolvedIps,omitempty" yaml:"resolvedIps,omitempty"`
|
||||
Selected bool `json:"selected" yaml:"selected"`
|
||||
}
|
||||
|
||||
// NetworksListOutput is the structured result of `netbird networks list`.
|
||||
type NetworksListOutput struct {
|
||||
Networks []NetworkOutput `json:"networks" yaml:"networks"`
|
||||
}
|
||||
|
||||
// JSON returns the NetworksListOutput as a JSON string.
|
||||
func (o *NetworksListOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the NetworksListOutput as a YAML string.
|
||||
func (o *NetworksListOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// ForwardingRuleOutput is one row of a forwarding-list response.
|
||||
// DestinationPort and TranslatedPort are pointers so that absent port info
|
||||
// (oneof not set in the protobuf) surfaces as a missing field rather than a
|
||||
// human-readable sentinel string.
|
||||
type ForwardingRuleOutput struct {
|
||||
TranslatedAddress string `json:"translatedAddress" yaml:"translatedAddress"`
|
||||
TranslatedHostname string `json:"translatedHostname,omitempty" yaml:"translatedHostname,omitempty"`
|
||||
Protocol string `json:"protocol" yaml:"protocol"`
|
||||
DestinationPort *string `json:"destinationPort,omitempty" yaml:"destinationPort,omitempty"`
|
||||
TranslatedPort *string `json:"translatedPort,omitempty" yaml:"translatedPort,omitempty"`
|
||||
}
|
||||
|
||||
// ForwardingListOutput is the structured result of `netbird forwarding list`.
|
||||
type ForwardingListOutput struct {
|
||||
Rules []ForwardingRuleOutput `json:"rules" yaml:"rules"`
|
||||
}
|
||||
|
||||
// JSON returns the ForwardingListOutput as a JSON string.
|
||||
func (o *ForwardingListOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the ForwardingListOutput as a YAML string.
|
||||
func (o *ForwardingListOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// ProfileOutput is one row of a profile-list response.
|
||||
type ProfileOutput struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Active bool `json:"active" yaml:"active"`
|
||||
}
|
||||
|
||||
// ProfileListOutput is the structured result of `netbird profile list`.
|
||||
type ProfileListOutput struct {
|
||||
Profiles []ProfileOutput `json:"profiles" yaml:"profiles"`
|
||||
}
|
||||
|
||||
// JSON returns the ProfileListOutput as a JSON string.
|
||||
func (o *ProfileListOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the ProfileListOutput as a YAML string.
|
||||
func (o *ProfileListOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// StateOutput is one row of a state-list response.
|
||||
type StateOutput struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
}
|
||||
|
||||
// StateListOutput is the structured result of `netbird state list`.
|
||||
type StateListOutput struct {
|
||||
States []StateOutput `json:"states" yaml:"states"`
|
||||
}
|
||||
|
||||
// JSON returns the StateListOutput as a JSON string.
|
||||
func (o *StateListOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the StateListOutput as a YAML string.
|
||||
func (o *StateListOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// DeregisterOutput is the structured result of `netbird deregister`.
|
||||
type DeregisterOutput struct {
|
||||
Status string `json:"status" yaml:"status"` // "deregistered"
|
||||
ProfileName string `json:"profileName,omitempty" yaml:"profileName,omitempty"`
|
||||
}
|
||||
|
||||
// JSON returns the DeregisterOutput as a JSON string.
|
||||
func (o *DeregisterOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the DeregisterOutput as a YAML string.
|
||||
func (o *DeregisterOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// LoginOutput is the structured result of `netbird login` after successful auth.
|
||||
type LoginOutput struct {
|
||||
Status string `json:"status" yaml:"status"` // "logged_in"
|
||||
ProfileName string `json:"profileName,omitempty" yaml:"profileName,omitempty"`
|
||||
}
|
||||
|
||||
// JSON returns the LoginOutput as a JSON string.
|
||||
func (o *LoginOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the LoginOutput as a YAML string.
|
||||
func (o *LoginOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// ProfileMutationOutput is the structured result of `netbird profile add`,
|
||||
// `remove`, or `select`.
|
||||
type ProfileMutationOutput struct {
|
||||
Status string `json:"status" yaml:"status"` // "added" | "removed" | "selected"
|
||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||
}
|
||||
|
||||
// JSON returns the ProfileMutationOutput as a JSON string.
|
||||
func (o *ProfileMutationOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the ProfileMutationOutput as a YAML string.
|
||||
func (o *ProfileMutationOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// NetworksMutationOutput is the structured result of `netbird networks select`
|
||||
// or `netbird networks deselect`.
|
||||
type NetworksMutationOutput struct {
|
||||
Status string `json:"status" yaml:"status"` // "selected" | "deselected"
|
||||
Networks []string `json:"networks,omitempty" yaml:"networks,omitempty"`
|
||||
All bool `json:"all,omitempty" yaml:"all,omitempty"`
|
||||
}
|
||||
|
||||
// JSON returns the NetworksMutationOutput as a JSON string.
|
||||
func (o *NetworksMutationOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the NetworksMutationOutput as a YAML string.
|
||||
func (o *NetworksMutationOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// DebugBundleOutput is the structured result of `netbird debug bundle`.
|
||||
type DebugBundleOutput struct {
|
||||
Path string `json:"path" yaml:"path"`
|
||||
UploadedKey string `json:"uploadedKey,omitempty" yaml:"uploadedKey,omitempty"`
|
||||
}
|
||||
|
||||
// JSON returns the DebugBundleOutput as a JSON string.
|
||||
func (o *DebugBundleOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the DebugBundleOutput as a YAML string.
|
||||
func (o *DebugBundleOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// VersionOutput is the structured result of `netbird version`.
|
||||
type VersionOutput struct {
|
||||
Version string `json:"version" yaml:"version"`
|
||||
}
|
||||
|
||||
// JSON returns the VersionOutput as a JSON string.
|
||||
func (o *VersionOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the VersionOutput as a YAML string.
|
||||
func (o *VersionOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// JSON returns the DownOutput as a JSON string.
|
||||
func (o *DownOutput) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the DownOutput as a YAML string.
|
||||
func (o *DownOutput) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// GeneralSummary returns a general summary of the status overview.
|
||||
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
||||
var managementConnString string
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
@@ -47,6 +48,11 @@ type EphemeralManager struct {
|
||||
|
||||
lifeTime time.Duration
|
||||
cleanupWindow time.Duration
|
||||
|
||||
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
|
||||
// no-op when the receiver is nil so deployments without an app
|
||||
// metrics provider work unchanged.
|
||||
metrics *telemetry.EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// NewEphemeralManager instantiate new EphemeralManager
|
||||
@@ -60,6 +66,15 @@ func NewEphemeralManager(store store.Store, peersManager peers.Manager) *Ephemer
|
||||
}
|
||||
}
|
||||
|
||||
// SetMetrics attaches a metrics collector. Safe to call once before
|
||||
// LoadInitialPeers; later attachment is fine but earlier loads won't be
|
||||
// reflected in the gauge. Pass nil to detach.
|
||||
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
|
||||
e.peersLock.Lock()
|
||||
e.metrics = m
|
||||
e.peersLock.Unlock()
|
||||
}
|
||||
|
||||
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
||||
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
||||
// head.
|
||||
@@ -97,7 +112,9 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
e.removePeer(peer.ID)
|
||||
if e.removePeer(peer.ID) {
|
||||
e.metrics.DecPending(1)
|
||||
}
|
||||
|
||||
// stop the unnecessary timer
|
||||
if e.headPeer == nil && e.timer != nil {
|
||||
@@ -123,6 +140,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
}
|
||||
|
||||
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
|
||||
e.metrics.IncPending()
|
||||
if e.timer == nil {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
@@ -145,6 +163,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
for _, p := range peers {
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
e.metrics.AddPending(int64(len(peers)))
|
||||
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
|
||||
}
|
||||
@@ -181,6 +200,15 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
|
||||
e.peersLock.Unlock()
|
||||
|
||||
// Drop the gauge by the number of entries we just took off the list,
|
||||
// regardless of whether the subsequent DeletePeers call succeeds. The
|
||||
// list invariant is what the gauge tracks; failed delete batches are
|
||||
// counted separately via CountCleanupError so we can still see them.
|
||||
if len(deletePeers) > 0 {
|
||||
e.metrics.CountCleanupRun()
|
||||
e.metrics.DecPending(int64(len(deletePeers)))
|
||||
}
|
||||
|
||||
peerIDsPerAccount := make(map[string][]string)
|
||||
for id, p := range deletePeers {
|
||||
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
|
||||
@@ -191,7 +219,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
|
||||
e.metrics.CountCleanupError()
|
||||
continue
|
||||
}
|
||||
e.metrics.CountPeersCleaned(int64(len(peerIDs)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,9 +242,12 @@ func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline tim
|
||||
e.tailPeer = ep
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) removePeer(id string) {
|
||||
// removePeer drops the entry from the linked list. Returns true if a
|
||||
// matching entry was found and removed so callers can keep the pending
|
||||
// metric gauge in sync.
|
||||
func (e *EphemeralManager) removePeer(id string) bool {
|
||||
if e.headPeer == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
if e.headPeer.id == id {
|
||||
@@ -221,7 +255,7 @@ func (e *EphemeralManager) removePeer(id string) {
|
||||
if e.tailPeer.id == id {
|
||||
e.tailPeer = nil
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
for p := e.headPeer; p.next != nil; p = p.next {
|
||||
@@ -231,9 +265,10 @@ func (e *EphemeralManager) removePeer(id string) {
|
||||
e.tailPeer = p
|
||||
}
|
||||
p.next = p.next.next
|
||||
return
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) isPeerOnList(id string) bool {
|
||||
|
||||
@@ -112,7 +112,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
|
||||
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||
return Create(s, func() ephemeral.Manager {
|
||||
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||
em := manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||
if metrics := s.Metrics(); metrics != nil {
|
||||
em.SetMetrics(metrics.EphemeralPeersMetrics())
|
||||
}
|
||||
return em
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -394,6 +394,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
for _, m := range mappings[i:end] {
|
||||
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||
}
|
||||
m.AuthToken = token
|
||||
}
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
@@ -425,18 +432,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
|
||||
oidcCfg := s.GetOIDCValidationConfig()
|
||||
var mappings []*proto.ProxyMapping
|
||||
for _, service := range services {
|
||||
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
m := service.ToProtoMapping(rpservice.Create, "", oidcCfg)
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
||||
assert.Empty(t, stream.messages[0].Mapping)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
type hookingStream struct {
|
||||
grpc.ServerStream
|
||||
onSend func(*proto.GetMappingUpdateResponse)
|
||||
}
|
||||
|
||||
func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||
if s.onSend != nil {
|
||||
s.onSend(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *hookingStream) Context() context.Context { return context.Background() }
|
||||
func (s *hookingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *hookingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *hookingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *hookingStream) SendMsg(any) error { return nil }
|
||||
func (s *hookingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 6
|
||||
const ttl = 100 * time.Millisecond
|
||||
const sendDelay = 200 * time.Millisecond
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
s.tokenTTL = ttl
|
||||
|
||||
var validateErrs []error
|
||||
stream := &hookingStream{
|
||||
onSend: func(resp *proto.GetMappingUpdateResponse) {
|
||||
for _, m := range resp.Mapping {
|
||||
if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil {
|
||||
validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err))
|
||||
}
|
||||
}
|
||||
time.Sleep(sendDelay)
|
||||
},
|
||||
}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Empty(t, validateErrs,
|
||||
"tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness")
|
||||
}
|
||||
|
||||
@@ -522,10 +522,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
uncanceledCTX := context.WithoutCancel(ctx)
|
||||
unlock := s.acquirePeerLockByUID(uncanceledCTX, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
|
||||
s.cancelPeerRoutinesWithoutLock(uncanceledCTX, accountID, peer, streamStartTime)
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||
|
||||
@@ -326,17 +326,25 @@ func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context,
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testValidateSessionProxyManager struct{}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _, _ string, _ *string, _ *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ *proxy.Proxy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -291,10 +291,15 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Canonicalize the incoming range so a caller-supplied prefix with host bits
|
||||
// (e.g. 100.64.1.1/16) compares equal to the masked form stored on network.Net.
|
||||
newSettings.NetworkRange = newSettings.NetworkRange.Masked()
|
||||
|
||||
var oldSettings *types.Settings
|
||||
var updateAccountPeers bool
|
||||
var groupChangesAffectPeers bool
|
||||
var reloadReverseProxy bool
|
||||
var effectiveOldNetworkRange netip.Prefix
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var groupsUpdated bool
|
||||
@@ -308,6 +313,16 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
// No lock: the transaction already holds Settings(Update), and network.Net is
|
||||
// only mutated by reallocateAccountPeerIPs, which is reachable only through
|
||||
// this same code path. A Share lock here would extend an unnecessary row lock
|
||||
// and complicate ordering against updatePeerIPv6InTransaction.
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account network: %w", err)
|
||||
}
|
||||
effectiveOldNetworkRange = prefixFromIPNet(network.Net)
|
||||
|
||||
if oldSettings.Extra != nil && newSettings.Extra != nil &&
|
||||
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
|
||||
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
|
||||
@@ -321,7 +336,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
||||
if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange {
|
||||
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -396,9 +411,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta)
|
||||
}
|
||||
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
||||
if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange {
|
||||
eventMeta := map[string]any{
|
||||
"old_network_range": oldSettings.NetworkRange.String(),
|
||||
"old_network_range": effectiveOldNetworkRange.String(),
|
||||
"new_network_range": newSettings.NetworkRange.String(),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||
@@ -443,6 +458,22 @@ func ipv6SettingsChanged(old, updated *types.Settings) bool {
|
||||
return !slices.Equal(oldGroups, newGroups)
|
||||
}
|
||||
|
||||
// prefixFromIPNet returns the overlay prefix actually allocated on the account
|
||||
// network, or an invalid prefix if none is set. Settings.NetworkRange is a
|
||||
// user-facing override that is empty on legacy accounts, so the effective
|
||||
// range must be read from network.Net to compare against an incoming update.
|
||||
func prefixFromIPNet(ipNet net.IPNet) netip.Prefix {
|
||||
if ipNet.IP == nil {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(ipNet.IP)
|
||||
if !ok {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
ones, _ := ipNet.Mask.Size()
|
||||
return netip.PrefixFrom(addr.Unmap(), ones)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
@@ -1837,35 +1868,32 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
// SyncAndMarkPeer is the per-Sync entry point: it refreshes the peer's
|
||||
// network map and then marks the peer connected with a session token
|
||||
// derived from syncTime (the moment the gRPC stream opened). Any
|
||||
// concurrent stream that started earlier loses the optimistic-lock race
|
||||
// in MarkPeerConnected and bails without writing.
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
|
||||
if err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return peer, netMap, postureChecks, dnsfwdPort, nil
|
||||
}
|
||||
|
||||
// OnPeerDisconnected is invoked when a sync stream ends. It marks the
|
||||
// peer disconnected only when the stored SessionStartedAt matches the
|
||||
// nanosecond token derived from streamStartTime — i.e. only when this
|
||||
// is the stream that currently owns the peer's session. A mismatch
|
||||
// means a newer stream has already replaced us, so the disconnect is
|
||||
// dropped.
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if peer.Status.LastSeen.After(streamStartTime) {
|
||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
|
||||
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
|
||||
if err != nil {
|
||||
if err := am.MarkPeerDisconnected(ctx, peerPubKey, accountID, streamStartTime.UnixNano()); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -61,7 +61,8 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
|
||||
@@ -1305,17 +1305,31 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
func (m *MockManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerDisconnected", ctx, peerKey, accountID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected indicates an expected call of MarkPeerDisconnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerDisconnected(ctx, peerKey, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnected", reflect.TypeOf((*MockManager)(nil).MarkPeerDisconnected), ctx, peerKey, accountID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// OnPeerDisconnected mocks base method.
|
||||
|
||||
@@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1910,15 +1910,16 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err, "unable to get peer")
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should equal the token we passed in")
|
||||
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||
require.NoError(t, err)
|
||||
@@ -1926,49 +1927,127 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.False(t, peer.Status.Connected, "peer should be disconnected")
|
||||
require.Equal(t, int64(0), peer.Status.SessionStartedAt, "SessionStartedAt should be reset to 0")
|
||||
})
|
||||
|
||||
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
|
||||
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
|
||||
// Older stream tries to mark disconnect with its own (older) session token —
|
||||
// fencing kicks in and the write is dropped.
|
||||
staleStreamStartTime := streamStartTime.Add(-1 * time.Hour)
|
||||
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, staleStreamStartTime)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected,
|
||||
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
|
||||
"peer should remain connected because the stored session is newer than the disconnect token")
|
||||
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should still hold the winning stream's token")
|
||||
})
|
||||
|
||||
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
|
||||
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should still be connected")
|
||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
|
||||
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
|
||||
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should NOT be overwritten by stale token from blocked goroutine")
|
||||
})
|
||||
}
|
||||
|
||||
// TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace exercises the
|
||||
// fencing protocol under contention: many goroutines race to mark the
|
||||
// same peer connected with distinct session tokens at the same time.
|
||||
// The contract is that the highest token always wins and is what remains
|
||||
// in the store, regardless of execution order.
|
||||
func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
const workers = 16
|
||||
base := time.Now().UTC().UnixNano()
|
||||
tokens := make([]int64, workers)
|
||||
for i := range tokens {
|
||||
// Spread tokens by 1ms so the comparison is unambiguous; the
|
||||
// largest is index workers-1.
|
||||
tokens[i] = base + int64(i)*int64(time.Millisecond)
|
||||
}
|
||||
expected := tokens[workers-1]
|
||||
|
||||
var ready sync.WaitGroup
|
||||
ready.Add(workers)
|
||||
var start sync.WaitGroup
|
||||
start.Add(1)
|
||||
var done sync.WaitGroup
|
||||
done.Add(workers)
|
||||
|
||||
// require.* calls t.FailNow which is documented as unsafe from
|
||||
// non-test goroutines (it calls runtime.Goexit on the wrong stack and
|
||||
// races with the WaitGroup). Collect errors here and assert from the
|
||||
// main goroutine after done.Wait().
|
||||
errs := make(chan error, workers)
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
token := tokens[i]
|
||||
go func() {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
|
||||
}()
|
||||
}
|
||||
|
||||
ready.Wait()
|
||||
start.Done()
|
||||
done.Wait()
|
||||
close(errs)
|
||||
for err := range errs {
|
||||
require.NoError(t, err, "MarkPeerConnected must not error under contention")
|
||||
}
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected after the race")
|
||||
require.Equal(t, expected, peer.Status.SessionStartedAt,
|
||||
"the largest token must win regardless of execution order")
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
@@ -1991,7 +2070,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -3970,6 +4049,96 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved guards against
|
||||
// peer IP reallocation when a settings update carries the network range that is already
|
||||
// in use. Legacy accounts have Settings.NetworkRange unset in the DB while network.Net
|
||||
// holds the actual allocated overlay; the dashboard backfills the GET response from
|
||||
// network.Net and echoes the value back on PUT, so the diff must be against the
|
||||
// effective range to avoid renumbering every peer on an unrelated settings change.
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved(t *testing.T) {
|
||||
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id)
|
||||
require.NoError(t, err)
|
||||
require.False(t, settings.NetworkRange.IsValid(), "precondition: new accounts leave Settings.NetworkRange unset")
|
||||
|
||||
network, err := manager.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, account.Id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, network.Net.IP, "precondition: network.Net should be allocated")
|
||||
addr, ok := netip.AddrFromSlice(network.Net.IP)
|
||||
require.True(t, ok)
|
||||
ones, _ := network.Net.Mask.Size()
|
||||
effective := netip.PrefixFrom(addr.Unmap(), ones)
|
||||
require.True(t, effective.IsValid())
|
||||
|
||||
before := map[string]netip.Addr{peer1.ID: peer1.IP, peer2.ID: peer2.IP, peer3.ID: peer3.IP}
|
||||
|
||||
// Round-trip the effective range as if the dashboard echoed back the GET-backfilled value.
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
NetworkRange: effective,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, len(before))
|
||||
for _, p := range peers {
|
||||
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when range matches effective", p.ID)
|
||||
}
|
||||
|
||||
// Carrying the same range with host bits set must also be a no-op once canonicalized.
|
||||
hostBitsForm := netip.PrefixFrom(peer1.IP, ones)
|
||||
require.NotEqual(t, effective, hostBitsForm, "precondition: host-bit form should differ before masking")
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
NetworkRange: hostBitsForm,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
for _, p := range peers {
|
||||
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change for host-bit-set equivalent range", p.ID)
|
||||
}
|
||||
|
||||
// Omitting NetworkRange (invalid prefix) must also be a no-op.
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
for _, p := range peers {
|
||||
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when NetworkRange omitted", p.ID)
|
||||
}
|
||||
|
||||
// Sanity: an actually different range still triggers reallocation.
|
||||
newRange := netip.MustParsePrefix("100.99.0.0/16")
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
NetworkRange: newRange,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
for _, p := range peers {
|
||||
assert.True(t, newRange.Contains(p.IP), "peer %s should be in new range %s, got %s", p.ID, newRange, p.IP)
|
||||
assert.NotEqual(t, before[p.ID], p.IP, "peer %s IP should change on real range update", p.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_IPv6EnabledGroups(t *testing.T) {
|
||||
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
@@ -138,10 +140,13 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
|
||||
}
|
||||
|
||||
// Build CLI redirect URIs including the device callback (both relative and absolute)
|
||||
// Build CLI redirect URIs including the device callback. Dex uses the issuer-relative
|
||||
// path (for example, /oauth2/device/callback) when completing the device flow, so
|
||||
// include it explicitly in addition to the legacy bare path and absolute URL.
|
||||
cliRedirectURIs := c.CLIRedirectURIs
|
||||
cliRedirectURIs = append(cliRedirectURIs, "/device/callback")
|
||||
cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback")
|
||||
cliRedirectURIs = append(cliRedirectURIs, issuerRelativeDeviceCallback(c.Issuer))
|
||||
cliRedirectURIs = append(cliRedirectURIs, strings.TrimSuffix(c.Issuer, "/")+"/device/callback")
|
||||
|
||||
// Build dashboard redirect URIs including the OAuth callback for proxy authentication
|
||||
dashboardRedirectURIs := c.DashboardRedirectURIs
|
||||
@@ -154,6 +159,10 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
// MGMT api and the dashboard, adding baseURL means less configuration for the instance admin
|
||||
dashboardPostLogoutRedirectURIs = append(dashboardPostLogoutRedirectURIs, baseURL)
|
||||
|
||||
redirectURIs := make([]string, 0)
|
||||
redirectURIs = append(redirectURIs, cliRedirectURIs...)
|
||||
redirectURIs = append(redirectURIs, dashboardRedirectURIs...)
|
||||
|
||||
cfg := &dex.YAMLConfig{
|
||||
Issuer: c.Issuer,
|
||||
Storage: dex.Storage{
|
||||
@@ -179,14 +188,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
ID: staticClientDashboard,
|
||||
Name: "NetBird Dashboard",
|
||||
Public: true,
|
||||
RedirectURIs: dashboardRedirectURIs,
|
||||
RedirectURIs: redirectURIs,
|
||||
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
|
||||
},
|
||||
{
|
||||
ID: staticClientCLI,
|
||||
Name: "NetBird CLI",
|
||||
Public: true,
|
||||
RedirectURIs: cliRedirectURIs,
|
||||
RedirectURIs: redirectURIs,
|
||||
},
|
||||
},
|
||||
StaticConnectors: c.StaticConnectors,
|
||||
@@ -217,6 +226,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func issuerRelativeDeviceCallback(issuer string) string {
|
||||
u, err := url.Parse(issuer)
|
||||
if err != nil || u.Path == "" {
|
||||
return "/device/callback"
|
||||
}
|
||||
return path.Join(u.Path, "/device/callback")
|
||||
}
|
||||
|
||||
// Due to how the frontend generates the logout, sometimes it appends a trailing slash
|
||||
// and because Dex only allows exact matches, we need to make sure we always have both
|
||||
// versions of each provided uri
|
||||
@@ -299,7 +316,7 @@ func resolveSessionCookieEncryptionKey(configuredKey string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key: %s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
|
||||
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key:%s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
|
||||
}
|
||||
|
||||
func validSessionCookieEncryptionKeyLength(length int) bool {
|
||||
|
||||
@@ -314,6 +314,34 @@ func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *testing.T) {
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "https://example.com/oauth2",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(t.TempDir(), "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
yamlConfig, err := config.ToYAMLConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
var cliRedirectURIs []string
|
||||
for _, client := range yamlConfig.StaticClients {
|
||||
if client.ID == staticClientCLI {
|
||||
cliRedirectURIs = client.RedirectURIs
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, cliRedirectURIs)
|
||||
assert.Contains(t, cliRedirectURIs, "/device/callback")
|
||||
assert.Contains(t, cliRedirectURIs, "/oauth2/device/callback")
|
||||
assert.Contains(t, cliRedirectURIs, "https://example.com/oauth2/device/callback")
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPConfig_ToYAMLConfig_SessionCookieEncryptionKey(t *testing.T) {
|
||||
t.Setenv(sessionCookieEncryptionKeyEnv, "")
|
||||
|
||||
|
||||
@@ -38,7 +38,8 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
@@ -227,7 +228,14 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
|
||||
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
func (am *MockAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
// Mirror DefaultAccountManager.OnPeerDisconnected: drive the fencing
|
||||
// hook so tests that inject MarkPeerDisconnectedFunc actually observe
|
||||
// disconnect events. Falls through to nil when no hook is set, which
|
||||
// is the original behaviour.
|
||||
if am.MarkPeerDisconnectedFunc != nil {
|
||||
return am.MarkPeerDisconnectedFunc(ctx, peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -328,13 +336,21 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mock implementation of MarkPeerDisconnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
|
||||
if am.MarkPeerDisconnectedFunc != nil {
|
||||
return am.MarkPeerDisconnectedFunc(ctx, peerKey, accountID, sessionStartedAt)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerDisconnected is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
||||
if am.DeleteAccountFunc != nil {
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
@@ -29,6 +28,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -63,56 +63,64 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
// syncTime is used as the LastSeen timestamp and for stale request detection
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
var peer *nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
var expired bool
|
||||
var err error
|
||||
var skipped bool
|
||||
// MarkPeerConnected marks a peer as connected with optimistic-locked
|
||||
// fencing on PeerStatus.SessionStartedAt. The sessionStartedAt argument
|
||||
// is the start time of the gRPC sync stream that owns this update,
|
||||
// expressed as Unix nanoseconds — only the call whose token is greater
|
||||
// than what's stored wins. LastSeen is written by the database itself;
|
||||
// we never pass it down.
|
||||
//
|
||||
// Disconnects use MarkPeerDisconnected and require the session to match
|
||||
// exactly; see PeerStatus.SessionStartedAt for the protocol.
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
|
||||
}()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
outcome := telemetry.PeerStatusError
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
outcome = telemetry.PeerStatusPeerNotFound
|
||||
}
|
||||
|
||||
if connected && !syncTime.After(peer.Status.LastSeen) {
|
||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
|
||||
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
|
||||
skipped = true
|
||||
return nil
|
||||
}
|
||||
|
||||
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, outcome)
|
||||
return err
|
||||
})
|
||||
if skipped {
|
||||
}
|
||||
|
||||
updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt)
|
||||
if err != nil {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusError)
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusStale)
|
||||
log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID)
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
|
||||
|
||||
if am.geo != nil && realIP != nil {
|
||||
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
|
||||
}
|
||||
|
||||
expired := peer.Status != nil && peer.Status.LoginExpired
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
||||
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
if expired {
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||
if err != nil {
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -120,41 +128,60 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
return nil
|
||||
}
|
||||
|
||||
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = syncTime
|
||||
newStatus.Connected = connected
|
||||
// whenever peer got connected that means that it logged in successfully
|
||||
if newStatus.Connected {
|
||||
newStatus.LoginExpired = false
|
||||
}
|
||||
peer.Status = newStatus
|
||||
// MarkPeerDisconnected marks a peer as disconnected, but only when the
|
||||
// stored session token matches the one passed in. A mismatch means a
|
||||
// newer stream has already taken ownership of the peer — disconnects from
|
||||
// the older stream are ignored. LastSeen is written by the database.
|
||||
func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusDisconnect, time.Since(start))
|
||||
}()
|
||||
|
||||
if geo != nil && realIP != nil {
|
||||
location, err := geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
} else {
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
err = transaction.SavePeerLocation(ctx, accountID, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||
|
||||
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
outcome := telemetry.PeerStatusError
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
outcome = telemetry.PeerStatusPeerNotFound
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, outcome)
|
||||
return err
|
||||
}
|
||||
|
||||
return oldStatus.LoginExpired, nil
|
||||
updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt)
|
||||
if err != nil {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusError)
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusStale)
|
||||
log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping",
|
||||
peer.ID, sessionStartedAt)
|
||||
return nil
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
|
||||
return nil
|
||||
}
|
||||
|
||||
// updatePeerLocationIfChanged refreshes the geolocation on a separate
|
||||
// row update, only when the connection IP actually changed. Geo lookups
|
||||
// are expensive so we skip same-IP reconnects.
|
||||
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return
|
||||
}
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||
|
||||
@@ -74,8 +74,19 @@ type ProxyMeta struct {
|
||||
}
|
||||
|
||||
type PeerStatus struct { //nolint:revive
|
||||
// LastSeen is the last time peer was connected to the management service
|
||||
// LastSeen is the last time the peer status was updated (i.e. the last
|
||||
// time we observed the peer being alive on a sync stream). Written by
|
||||
// the database (CURRENT_TIMESTAMP) — callers do not supply it.
|
||||
LastSeen time.Time
|
||||
// SessionStartedAt records when the currently-active sync stream began,
|
||||
// stored as Unix nanoseconds. It acts as the optimistic-locking token
|
||||
// for status updates: a stream is only allowed to mutate the peer's
|
||||
// status when its own token strictly exceeds the stored token (when connecting)
|
||||
// or matches it exactly (for disconnects). Zero means "no
|
||||
// active session". Integer nanoseconds are used so equality is
|
||||
// precision-safe across drivers, and so the predicates compose to a
|
||||
// single bigint comparison.
|
||||
SessionStartedAt int64
|
||||
// Connected indicates whether peer is connected to the management service or not
|
||||
Connected bool
|
||||
// LoginExpired
|
||||
@@ -375,10 +386,14 @@ func (p *Peer) EventMeta(dnsDomain string) map[string]any {
|
||||
return meta
|
||||
}
|
||||
|
||||
// Copy PeerStatus
|
||||
// Copy PeerStatus. SessionStartedAt must be propagated so clone-based
|
||||
// callers (Peer.Copy, MarkLoginExpired, UpdateLastLogin) don't silently
|
||||
// reset the fencing token to zero — that would let any subsequent
|
||||
// SavePeerStatus write reopen the optimistic-lock window.
|
||||
func (p *PeerStatus) Copy() *PeerStatus {
|
||||
return &PeerStatus{
|
||||
LastSeen: p.LastSeen,
|
||||
SessionStartedAt: p.SessionStartedAt,
|
||||
Connected: p.Connected,
|
||||
LoginExpired: p.LoginExpired,
|
||||
RequiresApproval: p.RequiresApproval,
|
||||
|
||||
@@ -498,8 +498,9 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
|
||||
peerCopy.Status = &peerStatus
|
||||
|
||||
fieldsToUpdate := []string{
|
||||
"peer_status_last_seen", "peer_status_connected",
|
||||
"peer_status_login_expired", "peer_status_required_approval",
|
||||
"peer_status_last_seen", "peer_status_session_started_at",
|
||||
"peer_status_connected", "peer_status_login_expired",
|
||||
"peer_status_requires_approval",
|
||||
}
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Select(fieldsToUpdate).
|
||||
@@ -516,6 +517,69 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession is an atomic optimistic-locked update.
|
||||
// The peer is marked connected with the given session token only when
|
||||
// the stored SessionStartedAt is strictly smaller than the incoming
|
||||
// one — equivalently, when no newer stream has already taken ownership.
|
||||
// The sentinel zero (set on peer creation or after a disconnect) counts
|
||||
// as the smallest possible token. This is the write half of the
|
||||
// fencing protocol described on PeerStatus.SessionStartedAt.
|
||||
//
|
||||
// The post-write side effects in the caller — geo lookup,
|
||||
// schedulePeerLoginExpiration, checkAndSchedulePeerInactivityExpiration,
|
||||
// OnPeersUpdated — all run AFTER this method returns and are deliberately
|
||||
// outside the database write so they cannot extend the row-lock window.
|
||||
//
|
||||
// LastSeen is set to the database's clock (CURRENT_TIMESTAMP) at the
|
||||
// moment the row is written. The caller never supplies LastSeen because
|
||||
// the value would otherwise drift under lock contention — a Go-side
|
||||
// time.Now() taken before the write can land minutes later than the
|
||||
// actual UPDATE under load, which previously caused real ordering bugs.
|
||||
func (s *SqlStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&nbpeer.Peer{}).
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Where("peer_status_session_started_at < ?", newSessionStartedAt).
|
||||
Updates(map[string]any{
|
||||
"peer_status_connected": true,
|
||||
"peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
"peer_status_session_started_at": newSessionStartedAt,
|
||||
"peer_status_login_expired": false,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return false, status.Errorf(status.Internal, "mark peer connected: %v", result.Error)
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession is an atomic optimistic-locked update.
|
||||
// The peer is marked disconnected only when the stored SessionStartedAt
|
||||
// matches the incoming token — meaning the stream that owns the current
|
||||
// session is the one ending. If a newer stream has already replaced the
|
||||
// session, the update is skipped. LastSeen is set to CURRENT_TIMESTAMP at
|
||||
// write time; see MarkPeerConnectedIfNewerSession for the rationale.
|
||||
//
|
||||
// A zero sessionStartedAt is rejected at the call site; the underlying
|
||||
// WHERE on equality would otherwise match every never-connected peer.
|
||||
func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
|
||||
if sessionStartedAt == 0 {
|
||||
return false, nil
|
||||
}
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&nbpeer.Peer{}).
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Where("peer_status_session_started_at = ?", sessionStartedAt).
|
||||
Updates(map[string]any{
|
||||
"peer_status_connected": false,
|
||||
"peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
"peer_status_session_started_at": int64(0),
|
||||
})
|
||||
if result.Error != nil {
|
||||
return false, status.Errorf(status.Internal, "mark peer disconnected: %v", result.Error)
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
@@ -1723,9 +1787,10 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname,
|
||||
meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version,
|
||||
meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
|
||||
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
|
||||
peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
|
||||
location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1`
|
||||
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_session_started_at,
|
||||
peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip,
|
||||
location_country_code, location_city_name, location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6
|
||||
FROM peers WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1738,6 +1803,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
lastLogin, createdAt sql.NullTime
|
||||
sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
|
||||
peerStatusLastSeen sql.NullTime
|
||||
peerStatusSessionStartedAt sql.NullInt64
|
||||
peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool
|
||||
ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte
|
||||
metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
|
||||
@@ -1752,8 +1818,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
&allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform,
|
||||
&metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
|
||||
&metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities,
|
||||
&peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
|
||||
&locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6)
|
||||
&peerStatusLastSeen, &peerStatusSessionStartedAt, &peerStatusConnected, &peerStatusLoginExpired,
|
||||
&peerStatusRequiresApproval, &connIP, &locationCountryCode, &locationCityName, &locationGeoNameID,
|
||||
&proxyEmbedded, &proxyCluster, &ipv6)
|
||||
|
||||
if err == nil {
|
||||
if lastLogin.Valid {
|
||||
@@ -1780,6 +1847,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
if peerStatusLastSeen.Valid {
|
||||
p.Status.LastSeen = peerStatusLastSeen.Time
|
||||
}
|
||||
if peerStatusSessionStartedAt.Valid {
|
||||
p.Status.SessionStartedAt = peerStatusSessionStartedAt.Int64
|
||||
}
|
||||
if peerStatusConnected.Valid {
|
||||
p.Status.Connected = peerStatusConnected.Bool
|
||||
}
|
||||
|
||||
@@ -167,6 +167,21 @@ type Store interface {
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
// MarkPeerConnectedIfNewerSession sets the peer to connected with the
|
||||
// given session token, but only when the stored SessionStartedAt is
|
||||
// strictly less than newSessionStartedAt (the sentinel zero counts as
|
||||
// "older"). LastSeen is recorded by the database at the moment the
|
||||
// row is updated — never by the caller — so it always reflects the
|
||||
// real write time even under lock contention.
|
||||
// Returns true when the update happened, false when this stream lost
|
||||
// the race against a newer session.
|
||||
MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error)
|
||||
// MarkPeerDisconnectedIfSameSession sets the peer to disconnected and
|
||||
// resets SessionStartedAt to zero, but only when the stored
|
||||
// SessionStartedAt equals the given sessionStartedAt. LastSeen is
|
||||
// recorded by the database. Returns true when the update happened,
|
||||
// false when a newer session has taken over.
|
||||
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
|
||||
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
|
||||
DeletePeer(ctx context.Context, accountID string, peerID string) error
|
||||
|
||||
@@ -2878,6 +2878,36 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status)
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession mocks base method.
|
||||
func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession.
|
||||
func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession mocks base method.
|
||||
func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession.
|
||||
func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// SavePolicy mocks base method.
|
||||
func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -16,6 +16,8 @@ type AccountManagerMetrics struct {
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
peerMetaUpdateCount metric.Int64Counter
|
||||
peerStatusUpdateCounter metric.Int64Counter
|
||||
peerStatusUpdateDurationMs metric.Float64Histogram
|
||||
}
|
||||
|
||||
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
|
||||
@@ -64,6 +66,24 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// peerStatusUpdateCounter records every attempt to mark a peer as connected or disconnected
|
||||
peerStatusUpdateCounter, err := meter.Int64Counter("management.account.peer.status.update.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of peer status update attempts, labeled by operation (connect|disconnect) and outcome (applied|stale|error|peer_not_found)"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerStatusUpdateDurationMs, err := meter.Float64Histogram("management.account.peer.status.update.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
1, 5, 15, 25, 50, 100, 250, 500, 1000, 2000, 5000,
|
||||
),
|
||||
metric.WithDescription("Duration of a peer status update (fence UPDATE + post-write side effects), labeled by operation"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountManagerMetrics{
|
||||
ctx: ctx,
|
||||
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||
@@ -71,10 +91,35 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
updateAccountPeersCounter: updateAccountPeersCounter,
|
||||
networkMapObjectCount: networkMapObjectCount,
|
||||
peerMetaUpdateCount: peerMetaUpdateCount,
|
||||
peerStatusUpdateCounter: peerStatusUpdateCounter,
|
||||
peerStatusUpdateDurationMs: peerStatusUpdateDurationMs,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// PeerStatusOperation labels the kind of fence-locked peer status write.
|
||||
type PeerStatusOperation string
|
||||
|
||||
// PeerStatusOutcome labels how a fence-locked peer status write resolved.
|
||||
type PeerStatusOutcome string
|
||||
|
||||
const (
|
||||
PeerStatusConnect PeerStatusOperation = "connect"
|
||||
PeerStatusDisconnect PeerStatusOperation = "disconnect"
|
||||
|
||||
// PeerStatusApplied — the fence WHERE matched and the UPDATE landed.
|
||||
PeerStatusApplied PeerStatusOutcome = "applied"
|
||||
// PeerStatusStale — the fence WHERE rejected the write because a
|
||||
// newer session has already taken ownership (connect: stored token
|
||||
// >= incoming; disconnect: stored token != incoming).
|
||||
PeerStatusStale PeerStatusOutcome = "stale"
|
||||
// PeerStatusError — the store returned a non-NotFound error.
|
||||
PeerStatusError PeerStatusOutcome = "error"
|
||||
// PeerStatusPeerNotFound — the peer lookup failed (the peer was
|
||||
// deleted between the gRPC sync handshake and the status write).
|
||||
PeerStatusPeerNotFound PeerStatusOutcome = "peer_not_found"
|
||||
)
|
||||
|
||||
// CountUpdateAccountPeersDuration counts the duration of updating account peers
|
||||
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
|
||||
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
@@ -104,3 +149,23 @@ func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource,
|
||||
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
|
||||
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountPeerStatusUpdate increments the connect/disconnect counter,
|
||||
// labeled by operation and outcome. Both labels are bounded enums.
|
||||
func (metrics *AccountManagerMetrics) CountPeerStatusUpdate(op PeerStatusOperation, outcome PeerStatusOutcome) {
|
||||
metrics.peerStatusUpdateCounter.Add(metrics.ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("operation", string(op)),
|
||||
attribute.String("outcome", string(outcome)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// RecordPeerStatusUpdateDuration records the wall-clock time spent
|
||||
// running a peer status update (including post-write side effects),
|
||||
// labeled by operation.
|
||||
func (metrics *AccountManagerMetrics) RecordPeerStatusUpdateDuration(op PeerStatusOperation, d time.Duration) {
|
||||
metrics.peerStatusUpdateDurationMs.Record(metrics.ctx, float64(d.Nanoseconds())/1e6,
|
||||
metric.WithAttributes(attribute.String("operation", string(op))),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ type MockAppMetrics struct {
|
||||
StoreMetricsFunc func() *StoreMetrics
|
||||
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
||||
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
|
||||
EphemeralPeersMetricsFunc func() *EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// GetMeter mocks the GetMeter function of the AppMetrics interface
|
||||
@@ -103,6 +104,14 @@ func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||
return nil
|
||||
}
|
||||
|
||||
// EphemeralPeersMetrics mocks the MockAppMetrics function of the EphemeralPeersMetrics interface
|
||||
func (mock *MockAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics {
|
||||
if mock.EphemeralPeersMetricsFunc != nil {
|
||||
return mock.EphemeralPeersMetricsFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppMetrics is metrics interface
|
||||
type AppMetrics interface {
|
||||
GetMeter() metric2.Meter
|
||||
@@ -114,6 +123,7 @@ type AppMetrics interface {
|
||||
StoreMetrics() *StoreMetrics
|
||||
UpdateChannelMetrics() *UpdateChannelMetrics
|
||||
AccountManagerMetrics() *AccountManagerMetrics
|
||||
EphemeralPeersMetrics() *EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
|
||||
@@ -129,6 +139,7 @@ type defaultAppMetrics struct {
|
||||
storeMetrics *StoreMetrics
|
||||
updateChannelMetrics *UpdateChannelMetrics
|
||||
accountManagerMetrics *AccountManagerMetrics
|
||||
ephemeralMetrics *EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// IDPMetrics returns metrics for the idp package
|
||||
@@ -161,6 +172,11 @@ func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetr
|
||||
return appMetrics.accountManagerMetrics
|
||||
}
|
||||
|
||||
// EphemeralPeersMetrics returns metrics for the ephemeral peer cleanup loop
|
||||
func (appMetrics *defaultAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics {
|
||||
return appMetrics.ephemeralMetrics
|
||||
}
|
||||
|
||||
// Close stop application metrics HTTP handler and closes listener.
|
||||
func (appMetrics *defaultAppMetrics) Close() error {
|
||||
if appMetrics.listener == nil {
|
||||
@@ -245,6 +261,11 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
||||
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
|
||||
}
|
||||
|
||||
ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err)
|
||||
}
|
||||
|
||||
return &defaultAppMetrics{
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
@@ -254,6 +275,7 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
ephemeralMetrics: ephemeralMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -290,6 +312,11 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric
|
||||
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
|
||||
}
|
||||
|
||||
ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err)
|
||||
}
|
||||
|
||||
return &defaultAppMetrics{
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
@@ -300,5 +327,6 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
ephemeralMetrics: ephemeralMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
115
management/server/telemetry/ephemeral_metrics.go
Normal file
115
management/server/telemetry/ephemeral_metrics.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// EphemeralPeersMetrics tracks the ephemeral peer cleanup pipeline: how
|
||||
// many peers are currently scheduled for deletion, how many tick runs
|
||||
// the cleaner has performed, how many peers it has removed, and how
|
||||
// many delete batches failed.
|
||||
type EphemeralPeersMetrics struct {
|
||||
ctx context.Context
|
||||
|
||||
pending metric.Int64UpDownCounter
|
||||
cleanupRuns metric.Int64Counter
|
||||
peersCleaned metric.Int64Counter
|
||||
errors metric.Int64Counter
|
||||
}
|
||||
|
||||
// NewEphemeralPeersMetrics constructs the ephemeral cleanup counters.
|
||||
func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*EphemeralPeersMetrics, error) {
|
||||
pending, err := meter.Int64UpDownCounter("management.ephemeral.peers.pending",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of ephemeral peers currently waiting to be cleaned up"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cleanupRuns, err := meter.Int64Counter("management.ephemeral.cleanup.runs.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of ephemeral cleanup ticks that processed at least one peer"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peersCleaned, err := meter.Int64Counter("management.ephemeral.peers.cleaned.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Total number of ephemeral peers deleted by the cleanup loop"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errors, err := meter.Int64Counter("management.ephemeral.cleanup.errors.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of ephemeral cleanup batches (per account) that failed to delete"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &EphemeralPeersMetrics{
|
||||
ctx: ctx,
|
||||
pending: pending,
|
||||
cleanupRuns: cleanupRuns,
|
||||
peersCleaned: peersCleaned,
|
||||
errors: errors,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// All methods are nil-receiver safe so callers that haven't wired metrics
|
||||
// (tests, self-hosted with metrics off) can invoke them unconditionally.
|
||||
|
||||
// IncPending bumps the pending gauge when a peer is added to the cleanup list.
|
||||
func (m *EphemeralPeersMetrics) IncPending() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.pending.Add(m.ctx, 1)
|
||||
}
|
||||
|
||||
// AddPending bumps the pending gauge by n — used at startup when the
|
||||
// initial set of ephemeral peers is loaded from the store.
|
||||
func (m *EphemeralPeersMetrics) AddPending(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
m.pending.Add(m.ctx, n)
|
||||
}
|
||||
|
||||
// DecPending decreases the pending gauge — used both when a peer reconnects
|
||||
// before its deadline (removed from the list) and when a cleanup tick
|
||||
// actually deletes it.
|
||||
func (m *EphemeralPeersMetrics) DecPending(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
m.pending.Add(m.ctx, -n)
|
||||
}
|
||||
|
||||
// CountCleanupRun records one cleanup pass that processed >0 peers. Idle
|
||||
// ticks (nothing to do) deliberately don't increment so the rate
|
||||
// reflects useful work.
|
||||
func (m *EphemeralPeersMetrics) CountCleanupRun() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.cleanupRuns.Add(m.ctx, 1)
|
||||
}
|
||||
|
||||
// CountPeersCleaned records the number of peers a single tick deleted.
|
||||
func (m *EphemeralPeersMetrics) CountPeersCleaned(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
m.peersCleaned.Add(m.ctx, n)
|
||||
}
|
||||
|
||||
// CountCleanupError records a failed delete batch.
|
||||
func (m *EphemeralPeersMetrics) CountCleanupError() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.errors.Add(m.ctx, 1)
|
||||
}
|
||||
Reference in New Issue
Block a user