Add an enforce network validation flag for docker to not break previous functionality

This commit is contained in:
Jonny Booker
2025-06-09 23:06:29 +01:00
parent a4d4976103
commit a52260b49d
2 changed files with 60 additions and 21 deletions

View File

@@ -71,7 +71,8 @@ func CheckSocket(socketPath string) bool {
// IsWithinNewtNetwork checks if a provided target is within the newt network // IsWithinNewtNetwork checks if a provided target is within the newt network
func IsWithinNewtNetwork(socketPath string, containerNameAsHostname bool, targetAddress string, targetPort int) (bool, error) { func IsWithinNewtNetwork(socketPath string, containerNameAsHostname bool, targetAddress string, targetPort int) (bool, error) {
containers, err := ListContainers(socketPath, containerNameAsHostname) // Always enforce network validation
containers, err := ListContainers(socketPath, true, containerNameAsHostname)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to list Docker containers: %s", err) return false, fmt.Errorf("failed to list Docker containers: %s", err)
} }
@@ -106,7 +107,7 @@ func IsWithinNewtNetwork(socketPath string, containerNameAsHostname bool, target
} }
// ListContainers lists all Docker containers with their network information // ListContainers lists all Docker containers with their network information
func ListContainers(socketPath string, containerNameAsHostname bool) ([]Container, error) { func ListContainers(socketPath string, enforceNetworkValidation bool, containerNameAsHostname bool) ([]Container, error) {
// Use the provided socket path or default to standard location // Use the provided socket path or default to standard location
if socketPath == "" { if socketPath == "" {
socketPath = "/var/run/docker.sock" socketPath = "/var/run/docker.sock"
@@ -207,7 +208,7 @@ func ListContainers(socketPath string, containerNameAsHostname bool) ([]Containe
} }
// Don't continue returning this container if not in the newt network(s) // Don't continue returning this container if not in the newt network(s)
if !isInNewtNetwork { if enforceNetworkValidation && !isInNewtNetwork {
logger.Debug("container not found within the newt network, skipping: %s", name) logger.Debug("container not found within the newt network, skipping: %s", name)
continue continue
} }

40
main.go
View File

@@ -355,6 +355,8 @@ var (
dockerSocket string dockerSocket string
dockerContainerAsHostname string dockerContainerAsHostname string
dockerContainerAsHostnameBool bool dockerContainerAsHostnameBool bool
dockerEnforceNetworkValidation string
dockerEnforceNetworkValidationBool bool
) )
func main() { func main() {
@@ -369,6 +371,7 @@ func main() {
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
dockerSocket = os.Getenv("DOCKER_SOCKET") dockerSocket = os.Getenv("DOCKER_SOCKET")
dockerContainerAsHostname = os.Getenv("DOCKER_CONTAINER_NAME_AS_HOSTNAME") dockerContainerAsHostname = os.Getenv("DOCKER_CONTAINER_NAME_AS_HOSTNAME")
dockerEnforceNetworkValidation = os.Getenv("DOCKER_ENFORCE_NETWORK_VALIDATION")
if endpoint == "" { if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -400,6 +403,9 @@ func main() {
if dockerContainerAsHostname == "" { if dockerContainerAsHostname == "" {
flag.StringVar(&dockerContainerAsHostname, "docker-container-name-as-hostname", "false", "Use container name when hostname for networking (true or false)") flag.StringVar(&dockerContainerAsHostname, "docker-container-name-as-hostname", "false", "Use container name when hostname for networking (true or false)")
} }
if dockerEnforceNetworkValidation == "" {
flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)")
}
// do a --version check // do a --version check
version := flag.Bool("version", false, "Print the version") version := flag.Bool("version", false, "Print the version")
@@ -431,6 +437,13 @@ func main() {
dockerContainerAsHostnameBool = false dockerContainerAsHostnameBool = false
} }
// parse if we want to enforce container network validation
dockerEnforceNetworkValidationBool, err = strconv.ParseBool(dockerEnforceNetworkValidation)
if err != nil {
logger.Info("Docker enforce network validation cannot be parsed. Defaulting to 'false'")
dockerEnforceNetworkValidationBool = false
}
privateKey, err = wgtypes.GeneratePrivateKey() privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
logger.Fatal("Failed to generate private key: %v", err) logger.Fatal("Failed to generate private key: %v", err)
@@ -450,6 +463,25 @@ func main() {
logger.Fatal("Failed to create client: %v", err) logger.Fatal("Failed to create client: %v", err)
} }
// output env var values if set
logger.Debug("Endpoint: %v", endpoint)
logger.Debug("Log Level: %v", logLevel)
logger.Debug("Docker Container Name as Hostname: %v", dockerContainerAsHostnameBool)
logger.Debug("Docker Network Validation Enabled: %v", dockerEnforceNetworkValidationBool)
logger.Debug("TLS Private Key Set: %v", tlsPrivateKey != "")
if dns != "" {
logger.Debug("Dns: %v", dns)
}
if dockerSocket != "" {
logger.Debug("Docker Socket: %v", dockerSocket)
}
if mtu != "" {
logger.Debug("MTU: %v", mtu)
}
if updownScript != "" {
logger.Debug("Up Down Script: %v", updownScript)
}
// Create TUN device and network stack // Create TUN device and network stack
var tun tun.Device var tun tun.Device
var tnet *netstack.Net var tnet *netstack.Net
@@ -689,7 +721,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
} }
// List Docker containers // List Docker containers
containers, err := docker.ListContainers(dockerSocket, dockerContainerAsHostnameBool) containers, err := docker.ListContainers(dockerSocket, dockerEnforceNetworkValidationBool, dockerContainerAsHostnameBool)
if err != nil { if err != nil {
logger.Error("Failed to list Docker containers: %v", err) logger.Error("Failed to list Docker containers: %v", err)
return return
@@ -798,12 +830,18 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
} }
// Add the new target // Add the new target
if dockerEnforceNetworkValidationBool {
logger.Info("Enforcing docker network validation")
isWithinNewtNetwork, err := docker.IsWithinNewtNetwork(dockerSocket, dockerContainerAsHostnameBool, targetAddress, targetPort) isWithinNewtNetwork, err := docker.IsWithinNewtNetwork(dockerSocket, dockerContainerAsHostnameBool, targetAddress, targetPort)
if !isWithinNewtNetwork { if !isWithinNewtNetwork {
logger.Error("Not adding target: %v", err) logger.Error("Not adding target: %v", err)
} else { } else {
pm.AddTarget(proto, tunnelIP, port, processedTarget) pm.AddTarget(proto, tunnelIP, port, processedTarget)
} }
} else {
pm.AddTarget(proto, tunnelIP, port, processedTarget)
}
} else if action == "remove" { } else if action == "remove" {
logger.Info("Removing target with port %d", port) logger.Info("Removing target with port %d", port)