blob: 6356b32872a45cb30df342dac0341e2cfcde25c9 [file] [log] [blame]
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5// +build go1.6
6
7package http2
8
9import (
10 "crypto/tls"
11 "fmt"
12 "net/http"
13)
14
15func configureTransport(t1 *http.Transport) (*Transport, error) {
16 connPool := new(clientConnPool)
17 t2 := &Transport{
18 ConnPool: noDialClientConnPool{connPool},
19 t1: t1,
20 }
21 connPool.t = t2
22 if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
23 return nil, err
24 }
25 if t1.TLSClientConfig == nil {
26 t1.TLSClientConfig = new(tls.Config)
27 }
28 if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
29 t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
30 }
31 if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
32 t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
33 }
34 upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
35 addr := authorityAddr("https", authority)
36 if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
37 go c.Close()
38 return erringRoundTripper{err}
39 } else if !used {
40 // Turns out we don't need this c.
41 // For example, two goroutines made requests to the same host
42 // at the same time, both kicking off TCP dials. (since protocol
43 // was unknown)
44 go c.Close()
45 }
46 return t2
47 }
48 if m := t1.TLSNextProto; len(m) == 0 {
49 t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
50 "h2": upgradeFn,
51 }
52 } else {
53 m["h2"] = upgradeFn
54 }
55 return t2, nil
56}
57
58// registerHTTPSProtocol calls Transport.RegisterProtocol but
59// converting panics into errors.
60func registerHTTPSProtocol(t *http.Transport, rt noDialH2RoundTripper) (err error) {
61 defer func() {
62 if e := recover(); e != nil {
63 err = fmt.Errorf("%v", e)
64 }
65 }()
66 t.RegisterProtocol("https", rt)
67 return nil
68}
69
70// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
71// if there's already has a cached connection to the host.
72// (The field is exported so it can be accessed via reflect from net/http; tested
73// by TestNoDialH2RoundTripperType)
74type noDialH2RoundTripper struct{ *Transport }
75
76func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
77 res, err := rt.Transport.RoundTrip(req)
78 if isNoCachedConnError(err) {
79 return nil, http.ErrSkipAltProtocol
80 }
81 return res, err
82}