diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 97e4662fd..c37740587 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -245,33 +245,29 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu defer bufPool.Put(bufp) buffer := *bufp - if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil { - return fmt.Errorf("set read deadline: %w", err) - } - if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil { - return fmt.Errorf("set write deadline: %w", err) - } - for { - select { - case <-ctx.Done(): + if ctx.Err() != nil { return ctx.Err() - default: - n, err := src.Read(buffer) - if err != nil { - if isTimeout(err) { - continue - } - return fmt.Errorf("read from %s: %w", direction, err) - } - - _, err = dst.Write(buffer[:n]) - if err != nil { - return fmt.Errorf("write to %s: %w", direction, err) - } - - c.updateLastSeen() } + + if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil { + return fmt.Errorf("set read deadline: %w", err) + } + + n, err := src.Read(buffer) + if err != nil { + if isTimeout(err) { + continue + } + return fmt.Errorf("read from %s: %w", direction, err) + } + + _, err = dst.Write(buffer[:n]) + if err != nil { + return fmt.Errorf("write to %s: %w", direction, err) + } + + c.updateLastSeen() } }