15713: Don't sniff non-error response body.
[arvados.git] / sdk / go / httpserver / responsewriter.go
index d37822ffe3e5cd0f582a59a3ee45b1d322fed4ac..884f0d8d7e2c8a87dbf0cde11135bde24b6f32c0 100644 (file)
@@ -8,10 +8,13 @@ import (
        "net/http"
 )
 
+const sniffBytes = 1024
+
 type ResponseWriter interface {
        http.ResponseWriter
        WroteStatus() int
        WroteBodyBytes() int
+       Sniffed() []byte
 }
 
 // responseWriter wraps http.ResponseWriter and exposes the status
@@ -22,6 +25,7 @@ type responseWriter struct {
        wroteStatus    int   // Last status given to WriteHeader()
        wroteBodyBytes int   // Bytes successfully written
        err            error // Last error returned from Write()
+       sniffed        []byte
 }
 
 func WrapResponseWriter(orig http.ResponseWriter) ResponseWriter {
@@ -41,6 +45,11 @@ func (w *responseWriter) WriteHeader(s int) {
 }
 
 func (w *responseWriter) Write(data []byte) (n int, err error) {
+       if w.wroteStatus == 0 {
+               w.WriteHeader(http.StatusOK)
+       } else if w.wroteStatus >= 400 {
+               w.sniff(data)
+       }
        n, err = w.ResponseWriter.Write(data)
        w.wroteBodyBytes += n
        w.err = err
@@ -58,3 +67,17 @@ func (w *responseWriter) WroteBodyBytes() int {
 func (w *responseWriter) Err() error {
        return w.err
 }
+
+func (w *responseWriter) sniff(data []byte) {
+       max := sniffBytes - len(w.sniffed)
+       if max <= 0 {
+               return
+       } else if max < len(data) {
+               data = data[:max]
+       }
+       w.sniffed = append(w.sniffed, data...)
+}
+
+func (w *responseWriter) Sniffed() []byte {
+       return w.sniffed
+}