11package websocket
22
33import (
4- "fmt"
4+ "crypto/sha1"
5+ "encoding/base64"
56 "net/http"
7+ "net/url"
8+ "strings"
9+
10+ "golang.org/x/net/http/httpguts"
11+ "golang.org/x/xerrors"
612)
713
814// AcceptOption is an option that can be passed to Accept.
15+ // The implementations of this interface are printable.
916type AcceptOption interface {
1017 acceptOption ()
11- fmt.Stringer
1218}
1319
20+ type acceptSubprotocols []string
21+
22+ func (o acceptSubprotocols ) acceptOption () {}
23+
1424// AcceptSubprotocols list the subprotocols that Accept will negotiate with a client.
1525// The first protocol that a client supports will be negotiated.
16- // Pass "" as a subprotocol if you would like to allow the default protocol.
26+ // Pass "" as a subprotocol if you would like to allow the default protocol along with
27+ // specific subprotocols.
1728func AcceptSubprotocols (subprotocols ... string ) AcceptOption {
18- panic ( "TODO" )
29+ return acceptSubprotocols ( subprotocols )
1930}
2031
32+ type acceptOrigins []string
33+
34+ func (o acceptOrigins ) acceptOption () {}
35+
2136// AcceptOrigins lists the origins that Accept will accept.
2237// Accept will always accept r.Host as the origin so you do not need to
2338// specify that with this option.
2439//
2540// Use this option with caution to avoid exposing your WebSocket
2641// server to a CSRF attack.
2742// See https://stackoverflow.com/a/37837709/4283659
28- // You can use a * to specify wildcards.
43+ // You can use a * for wildcards.
2944func AcceptOrigins (origins ... string ) AcceptOption {
30- panic ( "TODO" )
45+ return AcceptOrigins ( origins ... )
3146}
3247
3348// Accept accepts a WebSocket handshake from a client and upgrades the
@@ -36,5 +51,121 @@ func AcceptOrigins(origins ...string) AcceptOption {
3651// InsecureAcceptOrigin is passed.
3752// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
3853func Accept (w http.ResponseWriter , r * http.Request , opts ... AcceptOption ) (* Conn , error ) {
39- panic ("TODO" )
54+ var subprotocols []string
55+ origins := []string {r .Host }
56+ for _ , opt := range opts {
57+ switch opt := opt .(type ) {
58+ case acceptOrigins :
59+ origins = []string (opt )
60+ case acceptSubprotocols :
61+ subprotocols = []string (opt )
62+ }
63+ }
64+
65+ if ! httpguts .HeaderValuesContainsToken (r .Header ["Connection" ], "Upgrade" ) {
66+ err := xerrors .Errorf ("websocket: protocol violation: Connection header does not contain Upgrade: %q" , r .Header .Get ("Connection" ))
67+ http .Error (w , err .Error (), http .StatusBadRequest )
68+ return nil , err
69+ }
70+
71+ if ! httpguts .HeaderValuesContainsToken (r .Header ["Upgrade" ], "websocket" ) {
72+ err := xerrors .Errorf ("websocket: protocol violation: Upgrade header does not contain websocket: %q" , r .Header .Get ("Upgrade" ))
73+ http .Error (w , err .Error (), http .StatusBadRequest )
74+ return nil , err
75+ }
76+
77+ if r .Method != "GET" {
78+ err := xerrors .Errorf ("websocket: protocol violation: handshake request method is not GET: %q" , r .Method )
79+ http .Error (w , err .Error (), http .StatusBadRequest )
80+ return nil , err
81+ }
82+
83+ if r .Header .Get ("Sec-WebSocket-Version" ) != "13" {
84+ err := xerrors .Errorf ("websocket: unsupported protocol version: %q" , r .Header .Get ("Sec-WebSocket-Version" ))
85+ http .Error (w , err .Error (), http .StatusBadRequest )
86+ return nil , err
87+ }
88+
89+ if r .Header .Get ("Sec-WebSocket-Key" ) == "" {
90+ err := xerrors .New ("websocket: protocol violation: missing Sec-WebSocket-Key" )
91+ http .Error (w , err .Error (), http .StatusBadRequest )
92+ return nil , err
93+ }
94+
95+ origins = append (origins , r .Host )
96+
97+ err := authenticateOrigin (r , origins )
98+ if err != nil {
99+ http .Error (w , err .Error (), http .StatusForbidden )
100+ return nil , err
101+ }
102+
103+ hj , ok := w .(http.Hijacker )
104+ if ! ok {
105+ err = xerrors .New ("websocket: response writer does not implement http.Hijacker" )
106+ http .Error (w , err .Error (), http .StatusInternalServerError )
107+ return nil , err
108+ }
109+
110+ w .Header ().Set ("Upgrade" , "websocket" )
111+ w .Header ().Set ("Connection" , "Upgrade" )
112+
113+ handleKey (w , r )
114+
115+ selectSubprotocol (w , r , subprotocols )
116+
117+ w .WriteHeader (http .StatusSwitchingProtocols )
118+
119+ c , brw , err := hj .Hijack ()
120+ if err != nil {
121+ err = xerrors .Errorf ("websocket: failed to hijack connection: %v" , err )
122+ http .Error (w , err .Error (), http .StatusInternalServerError )
123+ return nil , err
124+ }
125+
126+ _ = c
127+ _ = brw
128+
129+ return nil , nil
130+ }
131+
132+ func selectSubprotocol (w http.ResponseWriter , r * http.Request , subprotocols []string ) {
133+ clientSubprotocols := strings .Split (r .Header .Get ("Sec-WebSocket-Protocol" ), "\n " )
134+ for _ , sp := range subprotocols {
135+ for _ , cp := range clientSubprotocols {
136+ if sp == strings .TrimSpace (cp ) {
137+ w .Header ().Set ("Sec-WebSocket-Protocol" , sp )
138+ return
139+ }
140+ }
141+ }
142+ }
143+
144+ var keyGUID = []byte ("258EAFA5-E914-47DA-95CA-C5AB0DC85B11" )
145+
146+ func handleKey (w http.ResponseWriter , r * http.Request ) {
147+ key := r .Header .Get ("Sec-WebSocket-Key" )
148+ h := sha1 .New ()
149+ h .Write ([]byte (key ))
150+ h .Write (keyGUID )
151+
152+ responseKey := base64 .StdEncoding .EncodeToString (h .Sum (nil ))
153+ w .Header ().Set ("Sec-WebSocket-Accept" , responseKey )
154+ }
155+
156+ func authenticateOrigin (r * http.Request , origins []string ) error {
157+ origin := r .Header .Get ("Origin" )
158+ if origin == "" {
159+ return nil
160+ }
161+ u , err := url .Parse (origin )
162+ if err != nil {
163+ return xerrors .Errorf ("failed to parse Origin header %q: %v" , origin , err )
164+ }
165+ for _ , o := range origins {
166+ if u .Host == o {
167+ return nil
168+ }
169+ }
170+ return xerrors .New ("request origin is not authorized" )
40171}
0 commit comments