From a52260b49d7b10a2656b992c224aed13ae15159e Mon Sep 17 00:00:00 2001 From: Jonny Booker <1131478+JonnyBooker@users.noreply.github.com> Date: Mon, 9 Jun 2025 23:06:29 +0100 Subject: [PATCH] Add an enforce network validation flag for docker to not break previous functionality --- docker/client.go | 7 +++-- main.go | 74 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 60 insertions(+), 21 deletions(-) diff --git a/docker/client.go b/docker/client.go index 762ce28..aee7fdf 100644 --- a/docker/client.go +++ b/docker/client.go @@ -71,7 +71,8 @@ func CheckSocket(socketPath string) bool { // IsWithinNewtNetwork checks if a provided target is within the newt network func IsWithinNewtNetwork(socketPath string, containerNameAsHostname bool, targetAddress string, targetPort int) (bool, error) { - containers, err := ListContainers(socketPath, containerNameAsHostname) + // Always enforce network validation + containers, err := ListContainers(socketPath, true, containerNameAsHostname) if err != nil { return false, fmt.Errorf("failed to list Docker containers: %s", err) } @@ -106,7 +107,7 @@ func IsWithinNewtNetwork(socketPath string, containerNameAsHostname bool, target } // ListContainers lists all Docker containers with their network information -func ListContainers(socketPath string, containerNameAsHostname bool) ([]Container, error) { +func ListContainers(socketPath string, enforceNetworkValidation bool, containerNameAsHostname bool) ([]Container, error) { // Use the provided socket path or default to standard location if socketPath == "" { socketPath = "/var/run/docker.sock" @@ -207,7 +208,7 @@ func ListContainers(socketPath string, containerNameAsHostname bool) ([]Containe } // Don't continue returning this container if not in the newt network(s) - if !isInNewtNetwork { + if enforceNetworkValidation && !isInNewtNetwork { logger.Debug("container not found within the newt network, skipping: %s", name) continue } diff --git a/main.go b/main.go index 7cfb002..177f32e 100644 --- a/main.go +++ b/main.go @@ -341,20 +341,22 @@ func resolveDomain(domain string) (string, error) { } var ( - endpoint string - id string - secret string - mtu string - mtuInt int - dns string - privateKey wgtypes.Key - err error - logLevel string - updownScript string - tlsPrivateKey string - dockerSocket string - dockerContainerAsHostname string - dockerContainerAsHostnameBool bool + endpoint string + id string + secret string + mtu string + mtuInt int + dns string + privateKey wgtypes.Key + err error + logLevel string + updownScript string + tlsPrivateKey string + dockerSocket string + dockerContainerAsHostname string + dockerContainerAsHostnameBool bool + dockerEnforceNetworkValidation string + dockerEnforceNetworkValidationBool bool ) func main() { @@ -369,6 +371,7 @@ func main() { tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") dockerSocket = os.Getenv("DOCKER_SOCKET") dockerContainerAsHostname = os.Getenv("DOCKER_CONTAINER_NAME_AS_HOSTNAME") + dockerEnforceNetworkValidation = os.Getenv("DOCKER_ENFORCE_NETWORK_VALIDATION") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -400,6 +403,9 @@ func main() { if dockerContainerAsHostname == "" { flag.StringVar(&dockerContainerAsHostname, "docker-container-name-as-hostname", "false", "Use container name when hostname for networking (true or false)") } + if dockerEnforceNetworkValidation == "" { + flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)") + } // do a --version check version := flag.Bool("version", false, "Print the version") @@ -431,6 +437,13 @@ func main() { dockerContainerAsHostnameBool = false } + // parse if we want to enforce container network validation + dockerEnforceNetworkValidationBool, err = strconv.ParseBool(dockerEnforceNetworkValidation) + if err != nil { + logger.Info("Docker enforce network validation cannot be parsed. Defaulting to 'false'") + dockerEnforceNetworkValidationBool = false + } + privateKey, err = wgtypes.GeneratePrivateKey() if err != nil { logger.Fatal("Failed to generate private key: %v", err) @@ -450,6 +463,25 @@ func main() { logger.Fatal("Failed to create client: %v", err) } + // output env var values if set + logger.Debug("Endpoint: %v", endpoint) + logger.Debug("Log Level: %v", logLevel) + logger.Debug("Docker Container Name as Hostname: %v", dockerContainerAsHostnameBool) + logger.Debug("Docker Network Validation Enabled: %v", dockerEnforceNetworkValidationBool) + logger.Debug("TLS Private Key Set: %v", tlsPrivateKey != "") + if dns != "" { + logger.Debug("Dns: %v", dns) + } + if dockerSocket != "" { + logger.Debug("Docker Socket: %v", dockerSocket) + } + if mtu != "" { + logger.Debug("MTU: %v", mtu) + } + if updownScript != "" { + logger.Debug("Up Down Script: %v", updownScript) + } + // Create TUN device and network stack var tun tun.Device var tnet *netstack.Net @@ -689,7 +721,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } // List Docker containers - containers, err := docker.ListContainers(dockerSocket, dockerContainerAsHostnameBool) + containers, err := docker.ListContainers(dockerSocket, dockerEnforceNetworkValidationBool, dockerContainerAsHostnameBool) if err != nil { logger.Error("Failed to list Docker containers: %v", err) return @@ -798,9 +830,15 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto } // Add the new target - isWithinNewtNetwork, err := docker.IsWithinNewtNetwork(dockerSocket, dockerContainerAsHostnameBool, targetAddress, targetPort) - if !isWithinNewtNetwork { - logger.Error("Not adding target: %v", err) + if dockerEnforceNetworkValidationBool { + logger.Info("Enforcing docker network validation") + + isWithinNewtNetwork, err := docker.IsWithinNewtNetwork(dockerSocket, dockerContainerAsHostnameBool, targetAddress, targetPort) + if !isWithinNewtNetwork { + logger.Error("Not adding target: %v", err) + } else { + pm.AddTarget(proto, tunnelIP, port, processedTarget) + } } else { pm.AddTarget(proto, tunnelIP, port, processedTarget) }