Merge branch '18323-cwl-gpu2' refs #18323
[arvados.git] / sdk / go / httpserver / responsewriter.go
1 // Copyright (C) The Arvados Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: Apache-2.0
4
5 package httpserver
6
7 import (
8         "net/http"
9 )
10
11 const sniffBytes = 1024
12
13 type ResponseWriter interface {
14         http.ResponseWriter
15         WroteStatus() int
16         WroteBodyBytes() int
17         Sniffed() []byte
18 }
19
20 // responseWriter wraps http.ResponseWriter and exposes the status
21 // sent, the number of bytes sent to the client, and the last write
22 // error.
23 type responseWriter struct {
24         http.ResponseWriter
25         wroteStatus    int   // First status given to WriteHeader()
26         wroteBodyBytes int   // Bytes successfully written
27         err            error // Last error returned from Write()
28         sniffed        []byte
29 }
30
31 func WrapResponseWriter(orig http.ResponseWriter) ResponseWriter {
32         return &responseWriter{ResponseWriter: orig}
33 }
34
35 func (w *responseWriter) CloseNotify() <-chan bool {
36         if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
37                 return cn.CloseNotify()
38         }
39         return nil
40 }
41
42 func (w *responseWriter) WriteHeader(s int) {
43         if w.wroteStatus == 0 {
44                 w.wroteStatus = s
45         }
46         // ...else it's too late to change the status seen by the
47         // client -- but we call the wrapped WriteHeader() anyway so
48         // it can log a warning.
49         w.ResponseWriter.WriteHeader(s)
50 }
51
52 func (w *responseWriter) Write(data []byte) (n int, err error) {
53         if w.wroteStatus == 0 {
54                 w.WriteHeader(http.StatusOK)
55         } else if w.wroteStatus >= 400 {
56                 w.sniff(data)
57         }
58         n, err = w.ResponseWriter.Write(data)
59         w.wroteBodyBytes += n
60         w.err = err
61         return
62 }
63
64 func (w *responseWriter) WroteStatus() int {
65         return w.wroteStatus
66 }
67
68 func (w *responseWriter) WroteBodyBytes() int {
69         return w.wroteBodyBytes
70 }
71
72 func (w *responseWriter) Err() error {
73         return w.err
74 }
75
76 func (w *responseWriter) sniff(data []byte) {
77         max := sniffBytes - len(w.sniffed)
78         if max <= 0 {
79                 return
80         } else if max < len(data) {
81                 data = data[:max]
82         }
83         w.sniffed = append(w.sniffed, data...)
84 }
85
86 func (w *responseWriter) Sniffed() []byte {
87         return w.sniffed
88 }