diff --git a/.gitignore b/.gitignore index e69de29..6dd29b7 100644 --- a/.gitignore +++ b/.gitignore @@ -0,0 +1 @@ +bin/ \ No newline at end of file diff --git a/go.mod b/go.mod index 421d3be..2e3b644 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 require ( github.com/google/btree v1.1.2 // indirect + github.com/gorilla/websocket v1.5.3 // indirect golang.org/x/crypto v0.28.0 // indirect golang.org/x/net v0.30.0 // indirect golang.org/x/sys v0.26.0 // indirect diff --git a/go.sum b/go.sum index 3bf4a5d..682c547 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= diff --git a/main.go b/main.go index 98be895..a522df1 100644 --- a/main.go +++ b/main.go @@ -6,15 +6,13 @@ import ( "encoding/hex" "flag" "fmt" - "io" "log" "math/rand" - "net" "net/netip" + "newt/proxy" "os" "os/signal" "strings" - "sync" "syscall" "time" @@ -25,166 +23,6 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" ) -type ProxyTarget struct { - Protocol string - Listen string - Targets []string -} - -type ProxyManager struct { - targets []ProxyTarget - tnet *netstack.Net -} - -func NewProxyManager(tnet *netstack.Net) *ProxyManager { - return &ProxyManager{ - tnet: tnet, - } -} - -func (pm *ProxyManager) AddTarget(protocol, listen string, targets []string) { - pm.targets = append(pm.targets, ProxyTarget{ - Protocol: protocol, - Listen: listen, - Targets: targets, - }) -} - -func (pm *ProxyManager) Start() error { - for _, target := range pm.targets { - switch strings.ToLower(target.Protocol) { - case "tcp": - go pm.serveTCP(target) - case "udp": - go pm.serveUDP(target) - default: - return fmt.Errorf("unsupported protocol: %s", target.Protocol) - } - } - return nil -} - -func (pm *ProxyManager) serveTCP(target ProxyTarget) { - listener, err := pm.tnet.ListenTCP(&net.TCPAddr{ - IP: net.ParseIP(target.Listen), - Port: 0, - }) - if err != nil { - log.Printf("Failed to start TCP listener for %s: %v", target.Listen, err) - return - } - defer listener.Close() - - log.Printf("TCP proxy listening on %s", listener.Addr()) - - for { - conn, err := listener.Accept() - if err != nil { - log.Printf("Failed to accept TCP connection: %v", err) - continue - } - - go pm.handleTCPConnection(conn, target.Targets) - } -} - -func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, targets []string) { - defer clientConn.Close() - - // Round-robin through targets - targetIndex := 0 - target := targets[targetIndex] - targetIndex = (targetIndex + 1) % len(targets) - - serverConn, err := net.Dial("tcp", target) - if err != nil { - log.Printf("Failed to connect to target %s: %v", target, err) - return - } - defer serverConn.Close() - - var wg sync.WaitGroup - wg.Add(2) - - // Client -> Server - go func() { - defer wg.Done() - io.Copy(serverConn, clientConn) - }() - - // Server -> Client - go func() { - defer wg.Done() - io.Copy(clientConn, serverConn) - }() - - wg.Wait() -} - -func (pm *ProxyManager) serveUDP(target ProxyTarget) { - addr := &net.UDPAddr{ - IP: net.ParseIP(target.Listen), - Port: 0, - } - - conn, err := pm.tnet.ListenUDP(addr) - if err != nil { - log.Printf("Failed to start UDP listener for %s: %v", target.Listen, err) - return - } - defer conn.Close() - - log.Printf("UDP proxy listening on %s", conn.LocalAddr()) - - buffer := make([]byte, 65535) - targetIndex := 0 - - for { - // Read from the UDP connection - n, remoteAddr, err := conn.ReadFrom(buffer) - if err != nil { - log.Printf("Failed to read UDP packet: %v", err) - continue - } - - t := target.Targets[targetIndex] - targetIndex = (targetIndex + 1) % len(target.Targets) - - targetAddr, err := net.ResolveUDPAddr("udp", t) - if err != nil { - log.Printf("Failed to resolve target address %s: %v", target, err) - continue - } - - go func(data []byte, remote net.Addr) { - targetConn, err := net.DialUDP("udp", nil, targetAddr) - if err != nil { - log.Printf("Failed to connect to target %s: %v", target, err) - return - } - defer targetConn.Close() - - _, err = targetConn.Write(data) - if err != nil { - log.Printf("Failed to write to target: %v", err) - return - } - - response := make([]byte, 65535) - n, err := targetConn.Read(response) - if err != nil { - log.Printf("Failed to read response from target: %v", err) - return - } - - _, err = conn.WriteTo(response[:n], remote) - if err != nil { - log.Printf("Failed to write response to client: %v", err) - } - }(buffer[:n], remoteAddr) - } -} - func fixKey(key string) string { // Remove any whitespace key = strings.TrimSpace(key) @@ -293,18 +131,46 @@ persistent_keepalive_interval=5 ping(tnet, serverIP) // Create proxy manager - pm := NewProxyManager(tnet) + pm := proxy.NewProxyManager(tnet) // Add TCP targets if tcpTargets != "" { targets := strings.Split(tcpTargets, ",") - pm.AddTarget("tcp", listenIP, targets) + for _, t := range targets { + // Split the first number off of the target with : separator and use as the port + parts := strings.Split(t, ":") + if len(parts) != 2 { + log.Panicf("Invalid target: %s", t) + } + // get the port as a int + port := 0 + _, err := fmt.Sscanf(parts[0], "%d", &port) + if err != nil { + log.Panicf("Invalid port: %s", parts[0]) + } + target := parts[1] + pm.AddTarget("tcp", listenIP, port, target) + } } // Add UDP targets if udpTargets != "" { targets := strings.Split(udpTargets, ",") - pm.AddTarget("udp", listenIP, targets) + for _, t := range targets { + // Split the first number off of the target with : separator and use as the port + parts := strings.Split(t, ":") + if len(parts) != 2 { + log.Panicf("Invalid target: %s", t) + } + // get the port as a int + port := 0 + _, err := fmt.Sscanf(parts[0], "%d", &port) + if err != nil { + log.Panicf("Invalid port: %s", parts[0]) + } + target := parts[1] + pm.AddTarget("udp", listenIP, port, target) + } } // Start proxies @@ -313,6 +179,13 @@ persistent_keepalive_interval=5 log.Panic(err) } + url := "ws://localhost/api/v1/ws" + token := "your-auth-token" + + if err := websocket.connectWebSocket(url, token); err != nil { + log.Fatalf("WebSocket error: %v", err) + } + // Wait for interrupt signal sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) diff --git a/newt b/newt deleted file mode 100755 index 0ce5c93..0000000 Binary files a/newt and /dev/null differ diff --git a/proxy/manager.go b/proxy/manager.go new file mode 100644 index 0000000..65298e6 --- /dev/null +++ b/proxy/manager.go @@ -0,0 +1,246 @@ +package proxy + +import ( + "fmt" + "io" + "log" + "net" + "strings" + "sync" + + "golang.zx2c4.com/wireguard/tun/netstack" +) + +func NewProxyManager(tnet *netstack.Net) *ProxyManager { + return &ProxyManager{ + tnet: tnet, + } +} + +func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target string) { + pm.Lock() + defer pm.Unlock() + + newTarget := ProxyTarget{ + Protocol: protocol, + Listen: listen, + Port: port, + Target: target, + cancel: make(chan struct{}), + } + + pm.targets = append(pm.targets, newTarget) +} + +func (pm *ProxyManager) RemoveTarget(listen string, port int) error { + pm.Lock() + defer pm.Unlock() + + for i, target := range pm.targets { + if target.Listen == listen && target.Port == port { + // Signal the serving goroutine to stop + close(target.cancel) + + // Close the listener/connection + target.Lock() + if target.listener != nil { + target.listener.Close() + } + if target.udpConn != nil { + target.udpConn.Close() + } + target.Unlock() + + // Remove the target from the slice + pm.targets = append(pm.targets[:i], pm.targets[i+1:]...) + return nil + } + } + + return fmt.Errorf("target not found for %s:%d", listen, port) +} + +func (pm *ProxyManager) Start() error { + pm.RLock() + defer pm.RUnlock() + + for i := range pm.targets { + target := &pm.targets[i] // Use pointer to modify the target in the slice + switch strings.ToLower(target.Protocol) { + case "tcp": + go pm.serveTCP(target) + case "udp": + go pm.serveUDP(target) + default: + return fmt.Errorf("unsupported protocol: %s", target.Protocol) + } + } + return nil +} + +func (pm *ProxyManager) serveTCP(target *ProxyTarget) { + listener, err := pm.tnet.ListenTCP(&net.TCPAddr{ + IP: net.ParseIP(target.Listen), + Port: target.Port, + }) + log.Printf("Listening on %s:%d", target.Listen, target.Port) + if err != nil { + log.Printf("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err) + return + } + + target.Lock() + target.listener = listener + target.Unlock() + + defer listener.Close() + log.Printf("TCP proxy listening on %s", listener.Addr()) + + // Channel to signal active connections to close + done := make(chan struct{}) + var activeConns sync.WaitGroup + + // Goroutine to handle shutdown signal + go func() { + <-target.cancel + close(done) + listener.Close() + }() + + for { + conn, err := listener.Accept() + if err != nil { + select { + case <-target.cancel: + // Wait for active connections to finish + activeConns.Wait() + return + default: + log.Printf("Failed to accept TCP connection: %v", err) + continue + } + } + + activeConns.Add(1) + go func() { + defer activeConns.Done() + pm.handleTCPConnection(conn, target.Target, done) + }() + } +} + +func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string, done chan struct{}) { + defer clientConn.Close() + + serverConn, err := net.Dial("tcp", target) + if err != nil { + log.Printf("Failed to connect to target %s: %v", target, err) + return + } + defer serverConn.Close() + + var wg sync.WaitGroup + wg.Add(2) + + // Client -> Server + go func() { + defer wg.Done() + select { + case <-done: + return + default: + io.Copy(serverConn, clientConn) + } + }() + + // Server -> Client + go func() { + defer wg.Done() + select { + case <-done: + return + default: + io.Copy(clientConn, serverConn) + } + }() + + wg.Wait() +} + +func (pm *ProxyManager) serveUDP(target *ProxyTarget) { + addr := &net.UDPAddr{ + IP: net.ParseIP(target.Listen), + Port: target.Port, + } + + conn, err := pm.tnet.ListenUDP(addr) + if err != nil { + log.Printf("Failed to start UDP listener for %s:%d: %v", target.Listen, target.Port, err) + return + } + + target.Lock() + target.udpConn = conn + target.Unlock() + + defer conn.Close() + log.Printf("UDP proxy listening on %s", conn.LocalAddr()) + + buffer := make([]byte, 65535) + + for { + select { + case <-target.cancel: + return + default: + n, remoteAddr, err := conn.ReadFrom(buffer) + if err != nil { + select { + case <-target.cancel: + return + default: + log.Printf("Failed to read UDP packet: %v", err) + continue + } + } + + targetAddr, err := net.ResolveUDPAddr("udp", target.Target) + if err != nil { + log.Printf("Failed to resolve target address %s: %v", target.Target, err) + continue + } + + go func(data []byte, remote net.Addr) { + targetConn, err := net.DialUDP("udp", nil, targetAddr) + if err != nil { + log.Printf("Failed to connect to target %s: %v", target.Target, err) + return + } + defer targetConn.Close() + + select { + case <-target.cancel: + return + default: + _, err = targetConn.Write(data) + if err != nil { + log.Printf("Failed to write to target: %v", err) + return + } + + response := make([]byte, 65535) + n, err := targetConn.Read(response) + if err != nil { + log.Printf("Failed to read response from target: %v", err) + return + } + + _, err = conn.WriteTo(response[:n], remote) + if err != nil { + log.Printf("Failed to write response to client: %v", err) + } + } + }(buffer[:n], remoteAddr) + } + } +} diff --git a/proxy/types.go b/proxy/types.go new file mode 100644 index 0000000..f1e334f --- /dev/null +++ b/proxy/types.go @@ -0,0 +1,25 @@ +package proxy + +import ( + "net" + "sync" + + "golang.zx2c4.com/wireguard/tun/netstack" +) + +type ProxyTarget struct { + Protocol string + Listen string + Port int + Target string + cancel chan struct{} // Channel to signal shutdown + listener net.Listener // For TCP + udpConn net.PacketConn // For UDP + sync.Mutex // Protect access to connections +} + +type ProxyManager struct { + targets []ProxyTarget + tnet *netstack.Net + sync.RWMutex // Protect access to targets slice +} diff --git a/test/newt_client.sh b/test/newt_client.sh index 46d9419..211aa15 100644 --- a/test/newt_client.sh +++ b/test/newt_client.sh @@ -3,7 +3,7 @@ "--private-key=kAexrEV1OHlMYQU3BZatZxNfKGAbzo+ATspAdtOcRks=" \ "--public-key=Kn4eD0kvcTwjO//zqH/CtNVkMNdMiUkbqFxysEym2D8=" \ --endpoint=192.168.1.16:51820 \ - --tcp-targets=127.0.0.1:8080 \ - --udp-targets=127.0.0.1:53 \ + --tcp-targets=9999:127.0.0.1:8080 \ + --udp-targets=9953:127.0.0.1:53 \ --listen-ip=192.168.4.28 \ --server-ip=192.168.4.1 \ No newline at end of file diff --git a/websocket/manager.go b/websocket/manager.go new file mode 100644 index 0000000..9308295 --- /dev/null +++ b/websocket/manager.go @@ -0,0 +1,60 @@ +package websocket + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + + "github.com/gorilla/websocket" +) + +func connectWebSocket(url, token string) error { + // Create custom header with the auth token + header := http.Header{} + header.Add("Sec-WebSocket-Protocol", token) + + // Create dialer with default options + dialer := websocket.Dialer{ + EnableCompression: true, + } + + // Connect to WebSocket server + conn, resp, err := dialer.Dial(url, header) + if err != nil { + log.Printf("Dial failed: %v", err) + if resp != nil { + log.Printf("HTTP Response Status: %s", resp.Status) + } + return err + } + defer conn.Close() + + log.Printf("Connected to WebSocket server") + + // Message handling loop + for { + // Read message + messageType, message, err := conn.ReadMessage() + if err != nil { + log.Printf("Read error: %v", err) + return err + } + + // Handle text messages (JSON expected) + if messageType == websocket.TextMessage { + // Create a map to store the JSON data + var jsonData map[string]interface{} + + // Unmarshal the JSON message + if err := json.Unmarshal(message, &jsonData); err != nil { + log.Printf("JSON parsing error: %v", err) + // Continue reading messages even if one fails to parse + continue + } + + // Print the parsed JSON message + fmt.Printf("Received message: %+v\n", jsonData) + } + } +}