"github.com/sirupsen/logrus"
)
-func (sshconn ContainerSSHConnection) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+// ServeHTTP upgrades the client request into a raw bidirectional byte stream
+// proxied to cresp.Conn.
+//
+// It hijacks the client connection, answers 101 Switching Protocols with
+// "Connection: upgrade" plus every header in cresp.Header, and then copies
+// bytes in both directions until both copies have finished. Byte counts are
+// logged through cresp.Logger when it is non-nil. cresp.Conn is closed on
+// every return path.
+func (cresp ConnectionResponse) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ // Runs on every exit, including the early returns below. On the normal path
+ // the upstream goroutine has already closed cresp.Conn, so this second
+ // Close's error is simply discarded.
+ defer cresp.Conn.Close()
hj, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "ResponseWriter does not support connection upgrade", http.StatusInternalServerError)
return
}
w.Header().Set("Connection", "upgrade")
- w.Header().Set("Upgrade", "ssh")
+ // Copy cresp.Header onto the 101 response; an entry with the same key as a
+ // header set above (e.g. "Connection") replaces that value.
+ for k, v := range cresp.Header {
+ w.Header()[k] = v
+ }
w.WriteHeader(http.StatusSwitchingProtocols)
conn, bufrw, err := hj.Hijack()
if err != nil {
var bytesIn, bytesOut int64
var wg sync.WaitGroup
- ctx, cancel := context.WithCancel(context.Background())
+ // ctx is derived from req.Context() and is only passed to ctxlog.FromContext
+ // below; nothing waits on ctx.Done(), so cancel just releases the context.
+ ctx, cancel := context.WithCancel(req.Context())
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
- n, err := io.CopyN(conn, sshconn.Bufrw, int64(sshconn.Bufrw.Reader.Buffered()))
+ // Downstream (cresp.Conn -> client): first flush the bytes already
+ // buffered in cresp.Bufrw, then stream the rest of cresp.Conn.
+ n, err := io.CopyN(conn, cresp.Bufrw, int64(cresp.Bufrw.Reader.Buffered()))
bytesOut += n
if err == nil {
- n, err = io.Copy(conn, sshconn.Conn)
+ n, err = io.Copy(conn, cresp.Conn)
bytesOut += n
}
if err != nil {
- ctxlog.FromContext(req.Context()).WithError(err).Error("error copying downstream")
+ ctxlog.FromContext(ctx).WithError(err).Error("error copying downstream")
}
+ // Closing the hijacked client connection makes the upstream goroutine's
+ // read from conn fail, so that goroutine finishes as well.
+ // NOTE(review): the copy cut off this way typically reports a "use of
+ // closed network connection" error, which is logged at Error level even on
+ // a normal shutdown; consider ignoring errors.Is(err, net.ErrClosed).
+ conn.Close()
}()
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
- n, err := io.CopyN(sshconn.Conn, bufrw, int64(bufrw.Reader.Buffered()))
+ // Upstream (client -> cresp.Conn): first forward client bytes that were
+ // already buffered in bufrw (returned by Hijack), then stream the rest of conn.
+ n, err := io.CopyN(cresp.Conn, bufrw, int64(bufrw.Reader.Buffered()))
bytesIn += n
if err == nil {
- n, err = io.Copy(sshconn.Conn, conn)
+ n, err = io.Copy(cresp.Conn, conn)
bytesIn += n
}
if err != nil {
- ctxlog.FromContext(req.Context()).WithError(err).Error("error copying upstream")
+ ctxlog.FromContext(ctx).WithError(err).Error("error copying upstream")
}
+ // Closing cresp.Conn makes the downstream goroutine's read from it fail,
+ // so that goroutine finishes as well.
+ cresp.Conn.Close()
}()
- <-ctx.Done()
- if sshconn.Logger != nil {
- go func() {
- wg.Wait()
- sshconn.Logger.WithFields(logrus.Fields{
- "bytesIn": bytesIn,
- "bytesOut": bytesOut,
- }).Info("closed connection")
- }()
+ // Block until both directions are done. wg.Wait() also orders the
+ // goroutines' writes to bytesIn/bytesOut before the reads below.
+ wg.Wait()
+ if cresp.Logger != nil {
+ cresp.Logger.WithFields(logrus.Fields{
+ "bytesIn": bytesIn,
+ "bytesOut": bytesOut,
+ }).Info("closed connection")
}
}