diff --git a/relay/relay.go b/relay/relay.go index 5baf779..eb9261a 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -1,6 +1,7 @@ package relay import ( + "bufio" "bytes" "context" "encoding/binary" @@ -210,10 +211,15 @@ func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privat // Start sets up the UDP listener, worker pool, and begins reading packets. func (s *UDPProxyServer) Start() error { - // Fetch initial mappings. - if err := s.fetchInitialMappings(); err != nil { - return fmt.Errorf("failed to fetch initial mappings: %v", err) - } + // Fetch initial mappings asynchronously so a large (potentially 100MB+) + // response does not block the UDP listener from coming up. Any packets + // arriving for unknown mappings before the load completes will simply + // log and be repopulated via the hole-punch path. + go func() { + if err := s.fetchInitialMappings(); err != nil { + logger.Error("Failed to fetch initial mappings: %v", err) + } + }() udpAddr, err := net.ResolveUDPAddr("udp", s.addr) if err != nil { @@ -488,6 +494,7 @@ func (s *UDPProxyServer) decryptMessage(encMsg EncryptedHolePunchMessage) ([]byt } func (s *UDPProxyServer) fetchInitialMappings() error { + logger.Info("Requesting initial proxy mappings") body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.privateKey.PublicKey().String()))) resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body) if err != nil { @@ -499,24 +506,82 @@ func (s *UDPProxyServer) fetchInitialMappings() error { return fmt.Errorf("server returned non-OK status: %d, body: %s", resp.StatusCode, string(body)) } - data, err := io.ReadAll(resp.Body) + logger.Info("Received initial mappings, streaming decode") + + // Stream-decode the response instead of buffering the entire body + // (which can be 100MB+) and then re-walking it with json.Unmarshal. + // This both lowers peak memory and lets us start populating the + // sync.Map as entries arrive. + dec := json.NewDecoder(bufio.NewReaderSize(resp.Body, 1<<20)) + + // Expect opening '{' of the top-level object. + tok, err := dec.Token() if err != nil { - return fmt.Errorf("failed to read response body: %v", err) + return fmt.Errorf("failed to read opening token: %v", err) } - logger.Info("Received initial mappings: %s", string(data)) - var initialMappings InitialMappings - if err := json.Unmarshal(data, &initialMappings); err != nil { - return fmt.Errorf("failed to unmarshal initial mappings: %v", err) + if d, ok := tok.(json.Delim); !ok || d != '{' { + return fmt.Errorf("expected '{' at top level, got %v", tok) } - // Store mappings in our sync.Map. - for key, mapping := range initialMappings.Mappings { - // Initialize LastUsed timestamp for initial mappings - mapping.LastUsed = time.Now() - s.proxyMappings.Store(key, mapping) + + count := 0 + now := time.Now() + + for dec.More() { + keyTok, err := dec.Token() + if err != nil { + return fmt.Errorf("failed to read top-level key: %v", err) + } + key, ok := keyTok.(string) + if !ok { + return fmt.Errorf("expected string key at top level, got %T", keyTok) + } + + if key != "mappings" { + // Skip unknown top-level fields without materializing them. + var skip json.RawMessage + if err := dec.Decode(&skip); err != nil { + return fmt.Errorf("failed to skip field %q: %v", key, err) + } + continue + } + + // Expect opening '{' of the mappings object. + tok, err := dec.Token() + if err != nil { + return fmt.Errorf("failed to read mappings open: %v", err) + } + if d, ok := tok.(json.Delim); !ok || d != '{' { + return fmt.Errorf("expected '{' for mappings, got %v", tok) + } + + for dec.More() { + mapKeyTok, err := dec.Token() + if err != nil { + return fmt.Errorf("failed to read mapping key: %v", err) + } + mapKey, ok := mapKeyTok.(string) + if !ok { + return fmt.Errorf("expected string mapping key, got %T", mapKeyTok) + } + + var mapping ProxyMapping + if err := dec.Decode(&mapping); err != nil { + return fmt.Errorf("failed to decode mapping %q: %v", mapKey, err) + } + mapping.LastUsed = now + s.proxyMappings.Store(mapKey, mapping) + count++ + } + + // Consume closing '}' of mappings object. + if _, err := dec.Token(); err != nil { + return fmt.Errorf("failed to read mappings close: %v", err) + } } - metrics.RecordProxyInitialMappings(relayIfname, int64(len(initialMappings.Mappings))) - metrics.RecordProxyMapping(relayIfname, int64(len(initialMappings.Mappings))) - logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings)) + + metrics.RecordProxyInitialMappings(relayIfname, int64(count)) + metrics.RecordProxyMapping(relayIfname, int64(count)) + logger.Info("Loaded %d initial proxy mappings", count) return nil }