Add an enforce network validation flag for docker to not break previous functionality

This commit is contained in:
Jonny Booker
2025-06-09 23:06:29 +01:00
parent a4d4976103
commit a52260b49d
2 changed files with 60 additions and 21 deletions

View File

@@ -71,7 +71,8 @@ func CheckSocket(socketPath string) bool {
// IsWithinNewtNetwork checks if a provided target is within the newt network // IsWithinNewtNetwork checks if a provided target is within the newt network
func IsWithinNewtNetwork(socketPath string, containerNameAsHostname bool, targetAddress string, targetPort int) (bool, error) { func IsWithinNewtNetwork(socketPath string, containerNameAsHostname bool, targetAddress string, targetPort int) (bool, error) {
containers, err := ListContainers(socketPath, containerNameAsHostname) // Always enforce network validation
containers, err := ListContainers(socketPath, true, containerNameAsHostname)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to list Docker containers: %s", err) return false, fmt.Errorf("failed to list Docker containers: %s", err)
} }
@@ -106,7 +107,7 @@ func IsWithinNewtNetwork(socketPath string, containerNameAsHostname bool, target
} }
// ListContainers lists all Docker containers with their network information // ListContainers lists all Docker containers with their network information
func ListContainers(socketPath string, containerNameAsHostname bool) ([]Container, error) { func ListContainers(socketPath string, enforceNetworkValidation bool, containerNameAsHostname bool) ([]Container, error) {
// Use the provided socket path or default to standard location // Use the provided socket path or default to standard location
if socketPath == "" { if socketPath == "" {
socketPath = "/var/run/docker.sock" socketPath = "/var/run/docker.sock"
@@ -207,7 +208,7 @@ func ListContainers(socketPath string, containerNameAsHostname bool) ([]Containe
} }
// Don't continue returning this container if not in the newt network(s) // Don't continue returning this container if not in the newt network(s)
if !isInNewtNetwork { if enforceNetworkValidation && !isInNewtNetwork {
logger.Debug("container not found within the newt network, skipping: %s", name) logger.Debug("container not found within the newt network, skipping: %s", name)
continue continue
} }

74
main.go
View File

@@ -341,20 +341,22 @@ func resolveDomain(domain string) (string, error) {
} }
var ( var (
endpoint string endpoint string
id string id string
secret string secret string
mtu string mtu string
mtuInt int mtuInt int
dns string dns string
privateKey wgtypes.Key privateKey wgtypes.Key
err error err error
logLevel string logLevel string
updownScript string updownScript string
tlsPrivateKey string tlsPrivateKey string
dockerSocket string dockerSocket string
dockerContainerAsHostname string dockerContainerAsHostname string
dockerContainerAsHostnameBool bool dockerContainerAsHostnameBool bool
dockerEnforceNetworkValidation string
dockerEnforceNetworkValidationBool bool
) )
func main() { func main() {
@@ -369,6 +371,7 @@ func main() {
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
dockerSocket = os.Getenv("DOCKER_SOCKET") dockerSocket = os.Getenv("DOCKER_SOCKET")
dockerContainerAsHostname = os.Getenv("DOCKER_CONTAINER_NAME_AS_HOSTNAME") dockerContainerAsHostname = os.Getenv("DOCKER_CONTAINER_NAME_AS_HOSTNAME")
dockerEnforceNetworkValidation = os.Getenv("DOCKER_ENFORCE_NETWORK_VALIDATION")
if endpoint == "" { if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -400,6 +403,9 @@ func main() {
if dockerContainerAsHostname == "" { if dockerContainerAsHostname == "" {
flag.StringVar(&dockerContainerAsHostname, "docker-container-name-as-hostname", "false", "Use container name when hostname for networking (true or false)") flag.StringVar(&dockerContainerAsHostname, "docker-container-name-as-hostname", "false", "Use container name when hostname for networking (true or false)")
} }
if dockerEnforceNetworkValidation == "" {
flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)")
}
// do a --version check // do a --version check
version := flag.Bool("version", false, "Print the version") version := flag.Bool("version", false, "Print the version")
@@ -431,6 +437,13 @@ func main() {
dockerContainerAsHostnameBool = false dockerContainerAsHostnameBool = false
} }
// parse if we want to enforce container network validation
dockerEnforceNetworkValidationBool, err = strconv.ParseBool(dockerEnforceNetworkValidation)
if err != nil {
logger.Info("Docker enforce network validation cannot be parsed. Defaulting to 'false'")
dockerEnforceNetworkValidationBool = false
}
privateKey, err = wgtypes.GeneratePrivateKey() privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
logger.Fatal("Failed to generate private key: %v", err) logger.Fatal("Failed to generate private key: %v", err)
@@ -450,6 +463,25 @@ func main() {
logger.Fatal("Failed to create client: %v", err) logger.Fatal("Failed to create client: %v", err)
} }
// output env var values if set
logger.Debug("Endpoint: %v", endpoint)
logger.Debug("Log Level: %v", logLevel)
logger.Debug("Docker Container Name as Hostname: %v", dockerContainerAsHostnameBool)
logger.Debug("Docker Network Validation Enabled: %v", dockerEnforceNetworkValidationBool)
logger.Debug("TLS Private Key Set: %v", tlsPrivateKey != "")
if dns != "" {
logger.Debug("Dns: %v", dns)
}
if dockerSocket != "" {
logger.Debug("Docker Socket: %v", dockerSocket)
}
if mtu != "" {
logger.Debug("MTU: %v", mtu)
}
if updownScript != "" {
logger.Debug("Up Down Script: %v", updownScript)
}
// Create TUN device and network stack // Create TUN device and network stack
var tun tun.Device var tun tun.Device
var tnet *netstack.Net var tnet *netstack.Net
@@ -689,7 +721,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
} }
// List Docker containers // List Docker containers
containers, err := docker.ListContainers(dockerSocket, dockerContainerAsHostnameBool) containers, err := docker.ListContainers(dockerSocket, dockerEnforceNetworkValidationBool, dockerContainerAsHostnameBool)
if err != nil { if err != nil {
logger.Error("Failed to list Docker containers: %v", err) logger.Error("Failed to list Docker containers: %v", err)
return return
@@ -798,9 +830,15 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
} }
// Add the new target // Add the new target
isWithinNewtNetwork, err := docker.IsWithinNewtNetwork(dockerSocket, dockerContainerAsHostnameBool, targetAddress, targetPort) if dockerEnforceNetworkValidationBool {
if !isWithinNewtNetwork { logger.Info("Enforcing docker network validation")
logger.Error("Not adding target: %v", err)
isWithinNewtNetwork, err := docker.IsWithinNewtNetwork(dockerSocket, dockerContainerAsHostnameBool, targetAddress, targetPort)
if !isWithinNewtNetwork {
logger.Error("Not adding target: %v", err)
} else {
pm.AddTarget(proto, tunnelIP, port, processedTarget)
}
} else { } else {
pm.AddTarget(proto, tunnelIP, port, processedTarget) pm.AddTarget(proto, tunnelIP, port, processedTarget)
} }