mirror of
https://github.com/fosrl/gerbil.git
synced 2026-03-06 02:36:41 +00:00
Add holepunch udp server
This commit is contained in:
98
main.go
98
main.go
@@ -57,6 +57,20 @@ var (
|
|||||||
wgClient *wgctrl.Client
|
wgClient *wgctrl.Client
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Add this new type at the top with other type definitions
|
||||||
|
type ClientEndpoint struct {
|
||||||
|
OlmID string `json:"olmId"`
|
||||||
|
NewtID string `json:"newtId"`
|
||||||
|
IP string `json:"ip"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
Timestamp int64 `json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HolePunchMessage struct {
|
||||||
|
OlmID string `json:"olmId"`
|
||||||
|
NewtID string `json:"newtId"`
|
||||||
|
}
|
||||||
|
|
||||||
func parseLogLevel(level string) logger.LogLevel {
|
func parseLogLevel(level string) logger.LogLevel {
|
||||||
switch strings.ToUpper(level) {
|
switch strings.ToUpper(level) {
|
||||||
case "DEBUG":
|
case "DEBUG":
|
||||||
@@ -74,6 +88,82 @@ func parseLogLevel(level string) logger.LogLevel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update the startUDPServer function
|
||||||
|
func startUDPServer(addr string, server string) {
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to resolve UDP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to start UDP server: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
logger.Info("UDP server listening on %s", addr)
|
||||||
|
|
||||||
|
buffer := make([]byte, 1024)
|
||||||
|
for {
|
||||||
|
n, remoteAddr, err := conn.ReadFromUDP(buffer)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error reading UDP packet: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg HolePunchMessage
|
||||||
|
if err := json.Unmarshal(buffer[:n], &msg); err != nil {
|
||||||
|
logger.Error("Error unmarshaling message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create endpoint info
|
||||||
|
endpoint := ClientEndpoint{
|
||||||
|
OlmID: msg.OlmID,
|
||||||
|
NewtID: msg.NewtID,
|
||||||
|
IP: remoteAddr.IP.String(),
|
||||||
|
Port: remoteAddr.Port,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the endpoint info to the Olm server
|
||||||
|
go notifyServer(endpoint, server)
|
||||||
|
|
||||||
|
logger.Info("Received hole punch from %s:%d for Olm ID: %s",
|
||||||
|
remoteAddr.IP,
|
||||||
|
remoteAddr.Port,
|
||||||
|
msg.OlmID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add this new function
|
||||||
|
func notifyServer(endpoint ClientEndpoint, server string) {
|
||||||
|
jsonData, err := json.Marshal(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to marshal endpoint data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.Post(server,
|
||||||
|
"application/json",
|
||||||
|
bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to notify Olm server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
logger.Error("Olm server returned non-OK status: %d, body: %s",
|
||||||
|
resp.StatusCode,
|
||||||
|
string(body))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully notified Olm server about endpoint for ID: %s", endpoint.OlmID)
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
@@ -85,6 +175,7 @@ func main() {
|
|||||||
reachableAt string
|
reachableAt string
|
||||||
logLevel string
|
logLevel string
|
||||||
mtu string
|
mtu string
|
||||||
|
reportHolePunchTo string
|
||||||
)
|
)
|
||||||
|
|
||||||
interfaceName = os.Getenv("INTERFACE")
|
interfaceName = os.Getenv("INTERFACE")
|
||||||
@@ -96,6 +187,7 @@ func main() {
|
|||||||
reachableAt = os.Getenv("REACHABLE_AT")
|
reachableAt = os.Getenv("REACHABLE_AT")
|
||||||
logLevel = os.Getenv("LOG_LEVEL")
|
logLevel = os.Getenv("LOG_LEVEL")
|
||||||
mtu = os.Getenv("MTU")
|
mtu = os.Getenv("MTU")
|
||||||
|
reportHolePunchTo = os.Getenv("REPORT_HOLE_PUNCH_TO")
|
||||||
|
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
||||||
@@ -112,6 +204,9 @@ func main() {
|
|||||||
if reportBandwidthTo == "" {
|
if reportBandwidthTo == "" {
|
||||||
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "Address to listen on")
|
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "Address to listen on")
|
||||||
}
|
}
|
||||||
|
if reportHolePunchTo == "" {
|
||||||
|
flag.StringVar(&reportHolePunchTo, "reportHolePunchTo", "", "Address to listen on")
|
||||||
|
}
|
||||||
if generateAndSaveKeyTo == "" {
|
if generateAndSaveKeyTo == "" {
|
||||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||||
}
|
}
|
||||||
@@ -220,6 +315,9 @@ func main() {
|
|||||||
go periodicBandwidthCheck(reportBandwidthTo)
|
go periodicBandwidthCheck(reportBandwidthTo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// run the udp server
|
||||||
|
go startUDPServer(":21820", reportHolePunchTo)
|
||||||
|
|
||||||
http.HandleFunc("/peer", handlePeer)
|
http.HandleFunc("/peer", handlePeer)
|
||||||
logger.Info("Starting server on %s", listenAddr)
|
logger.Info("Starting server on %s", listenAddr)
|
||||||
logger.Fatal("Failed to start server: %v", http.ListenAndServe(listenAddr, nil))
|
logger.Fatal("Failed to start server: %v", http.ListenAndServe(listenAddr, nil))
|
||||||
|
|||||||
Reference in New Issue
Block a user