blob: 48d7150df43840a5ac0febaa95e5c3bd093895eb [file] [log] [blame]
Serge Bazanskicc25bdf2018-10-25 14:02:58 +02001// Copyright 2015 go-swagger maintainers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package middleware
16
17import (
18 stdContext "context"
19 "net/http"
20 "strings"
21 "sync"
22
23 "github.com/go-openapi/analysis"
24 "github.com/go-openapi/errors"
25 "github.com/go-openapi/loads"
26 "github.com/go-openapi/runtime"
27 "github.com/go-openapi/runtime/logger"
28 "github.com/go-openapi/runtime/middleware/untyped"
29 "github.com/go-openapi/spec"
30 "github.com/go-openapi/strfmt"
31)
32
33// Debug when true turns on verbose logging
34var Debug = logger.DebugEnabled()
35var Logger logger.Logger = logger.StandardLogger{}
36
37func debugLog(format string, args ...interface{}) {
38 if Debug {
39 Logger.Printf(format, args...)
40 }
41}
42
// A Builder can create middlewares: given a handler, it returns the handler
// to serve in its place.
type Builder func(next http.Handler) http.Handler
45
// PassthroughBuilder is the identity Builder: it hands back the handler it
// receives without wrapping it.
func PassthroughBuilder(handler http.Handler) http.Handler {
	return handler
}
48
49// RequestBinder is an interface for types to implement
50// when they want to be able to bind from a request
51type RequestBinder interface {
52 BindRequest(*http.Request, *MatchedRoute) error
53}
54
55// Responder is an interface for types to implement
56// when they want to be considered for writing HTTP responses
57type Responder interface {
58 WriteResponse(http.ResponseWriter, runtime.Producer)
59}
60
61// ResponderFunc wraps a func as a Responder interface
62type ResponderFunc func(http.ResponseWriter, runtime.Producer)
63
64// WriteResponse writes to the response
65func (fn ResponderFunc) WriteResponse(rw http.ResponseWriter, pr runtime.Producer) {
66 fn(rw, pr)
67}
68
69// Context is a type safe wrapper around an untyped request context
70// used throughout to store request context with the standard context attached
71// to the http.Request
72type Context struct {
73 spec *loads.Document
74 analyzer *analysis.Spec
75 api RoutableAPI
76 router Router
77}
78
79type routableUntypedAPI struct {
80 api *untyped.API
81 hlock *sync.Mutex
82 handlers map[string]map[string]http.Handler
83 defaultConsumes string
84 defaultProduces string
85}
86
// newRoutableUntypedAPI adapts an untyped.API to the RoutableAPI interface.
//
// For every operation in spec that has a handler registered in api, it
// builds an http.Handler that:
//   - looks up the matched route (preferring the one cached in the request
//     context),
//   - binds and validates the request,
//   - invokes the operation handler,
//   - renders the validation error, the handler error or the handler result
//     through context.Respond.
//
// Operations that have security requirements are additionally wrapped with
// newSecureAPI. Handlers are indexed by upper-cased HTTP method, then by path.
//
// It returns nil when spec or api is nil.
func newRoutableUntypedAPI(spec *loads.Document, api *untyped.API, context *Context) *routableUntypedAPI {
	if spec == nil || api == nil {
		return nil
	}

	// allocated lazily: stays nil when no operation has a registered handler
	var handlers map[string]map[string]http.Handler

	analyzer := analysis.New(spec.Spec())
	for method, operations := range analyzer.Operations() {
		upper := strings.ToUpper(method)
		for path, op := range operations {
			schemes := analyzer.SecurityRequirementsFor(op)

			oh, registered := api.OperationHandlerFor(method, path)
			if !registered {
				continue
			}

			if handlers == nil {
				handlers = make(map[string]map[string]http.Handler)
			}
			if handlers[upper] == nil {
				handlers[upper] = make(map[string]http.Handler)
			}

			var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// look up the matched route, reusing one cached in the request context
				route, withRoute, _ := context.RouteInfo(r)
				if withRoute != nil {
					r = withRoute
				}

				// bind and validate the request using reflection
				bound, validated, validationErr := context.BindAndValidate(r, route)
				r = validated

				// pick what to render: the validation failure, the handler
				// failure, or the handler's result
				var reply interface{}
				if validationErr != nil {
					reply = validationErr
				} else if result, err := oh.Handle(bound); err != nil {
					reply = err
				} else {
					reply = result
				}
				context.Respond(w, r, route.Produces, route, reply)
			})

			if len(schemes) > 0 {
				handler = newSecureAPI(context, handler)
			}
			handlers[upper][path] = handler
		}
	}

	return &routableUntypedAPI{
		api:             api,
		hlock:           new(sync.Mutex),
		handlers:        handlers,
		defaultProduces: api.DefaultProduces,
		defaultConsumes: api.DefaultConsumes,
	}
}
150
// HandlerFor returns the handler registered for method and path. The method
// is matched case-insensitively (it is upper-cased before lookup). The
// boolean reports whether a handler was found.
func (r *routableUntypedAPI) HandlerFor(method, path string) (http.Handler, bool) {
	r.hlock.Lock()
	defer r.hlock.Unlock()

	byPath, found := r.handlers[strings.ToUpper(method)]
	if !found {
		return nil, false
	}
	h, found := byPath[path]
	return h, found
}
// ServeErrorFor returns the error handler of the wrapped untyped API. The
// same handler is used for every operation, so operationID is not consulted.
func (r *routableUntypedAPI) ServeErrorFor(operationID string) func(http.ResponseWriter, *http.Request, error) {
	serve := r.api.ServeError
	return serve
}

// ConsumersFor delegates to the wrapped API's ConsumersFor.
func (r *routableUntypedAPI) ConsumersFor(mediaTypes []string) map[string]runtime.Consumer {
	return r.api.ConsumersFor(mediaTypes)
}

// ProducersFor delegates to the wrapped API's ProducersFor.
func (r *routableUntypedAPI) ProducersFor(mediaTypes []string) map[string]runtime.Producer {
	return r.api.ProducersFor(mediaTypes)
}

// AuthenticatorsFor delegates to the wrapped API's AuthenticatorsFor.
func (r *routableUntypedAPI) AuthenticatorsFor(schemes map[string]spec.SecurityScheme) map[string]runtime.Authenticator {
	return r.api.AuthenticatorsFor(schemes)
}

// Authorizer delegates to the wrapped API's Authorizer.
func (r *routableUntypedAPI) Authorizer() runtime.Authorizer {
	return r.api.Authorizer()
}

// Formats delegates to the wrapped API's Formats.
func (r *routableUntypedAPI) Formats() strfmt.Registry {
	return r.api.Formats()
}

// DefaultProduces returns the wrapped API's DefaultProduces, captured when
// the adapter was built.
func (r *routableUntypedAPI) DefaultProduces() string {
	return r.defaultProduces
}

// DefaultConsumes returns the wrapped API's DefaultConsumes, captured when
// the adapter was built.
func (r *routableUntypedAPI) DefaultConsumes() string {
	return r.defaultConsumes
}
188
189// NewRoutableContext creates a new context for a routable API
190func NewRoutableContext(spec *loads.Document, routableAPI RoutableAPI, routes Router) *Context {
191 var an *analysis.Spec
192 if spec != nil {
193 an = analysis.New(spec.Spec())
194 }
195 ctx := &Context{spec: spec, api: routableAPI, analyzer: an, router: routes}
196 return ctx
197}
198
199// NewContext creates a new context wrapper
200func NewContext(spec *loads.Document, api *untyped.API, routes Router) *Context {
201 var an *analysis.Spec
202 if spec != nil {
203 an = analysis.New(spec.Spec())
204 }
205 ctx := &Context{spec: spec, analyzer: an}
206 ctx.api = newRoutableUntypedAPI(spec, api, ctx)
207 ctx.router = routes
208 return ctx
209}
210
211// Serve serves the specified spec with the specified api registrations as a http.Handler
212func Serve(spec *loads.Document, api *untyped.API) http.Handler {
213 return ServeWithBuilder(spec, api, PassthroughBuilder)
214}
215
// ServeWithBuilder serves the specified spec with the specified api
// registrations as an http.Handler whose routes are decorated by builder.
// The Context is created without a router.
func ServeWithBuilder(spec *loads.Document, api *untyped.API, builder Builder) http.Handler {
	return NewContext(spec, api, nil).APIHandler(builder)
}
222
// contextKey is the unexported type of the keys this package stores in a
// request's context, so they cannot collide with keys of other packages.
type contextKey int8

// Request-context keys. Numbering starts at 1, leaving the zero contextKey
// unused.
const (
	ctxContentType contextKey = iota + 1
	ctxResponseFormat
	ctxMatchedRoute
	ctxBoundParams
	ctxSecurityPrincipal
	ctxSecurityScopes
)
234
// MatchedRouteFrom returns the *MatchedRoute stored in the request context,
// or nil when the context holds no route (or a value of another type).
func MatchedRouteFrom(req *http.Request) *MatchedRoute {
	route, _ := req.Context().Value(ctxMatchedRoute).(*MatchedRoute)
	return route
}
246
// SecurityPrincipalFrom returns the principal stored in the request context
// (as done by a successful Context.Authorize), or nil when none is present.
func SecurityPrincipalFrom(req *http.Request) interface{} {
	ctx := req.Context()
	return ctx.Value(ctxSecurityPrincipal)
}
251
// SecurityScopesFrom returns the security scopes stored in the request
// context, or nil when none are present or the value is not a []string.
func SecurityScopesFrom(req *http.Request) []string {
	scopes, _ := req.Context().Value(ctxSecurityScopes).([]string)
	return scopes
}
260
// contentTypeValue caches a parsed Content-Type header in the request
// context: the media type and its charset parameter.
type contentTypeValue struct {
	MediaType, Charset string
}
265
266// BasePath returns the base path for this API
267func (c *Context) BasePath() string {
268 return c.spec.BasePath()
269}
270
271// RequiredProduces returns the accepted content types for responses
272func (c *Context) RequiredProduces() []string {
273 return c.analyzer.RequiredProduces()
274}
275
// BindValidRequest binds a params object to a request, but only when the
// request passes the checks performed here.
//
// For a request with a body it:
//   - parses the Content-Type header and checks it against route.Consumes,
//   - selects the matching consumer from route.Consumers, storing it in
//     route.Consumer (a 500 error is recorded when none is registered),
//   - verifies that a response format can be negotiated against
//     route.Produces.
//
// Failures from these checks are returned as a CompositeValidationError and
// binder is not invoked. Otherwise a non-nil binder is called and its error,
// if any, is returned unchanged.
func (c *Context) BindValidRequest(request *http.Request, route *MatchedRoute, binder RequestBinder) error {
	var problems []error

	requestContentType := "*/*"
	if runtime.HasBody(request) {
		if ct, _, err := runtime.ContentType(request.Header); err != nil {
			problems = append(problems, err)
		} else if err := validateContentType(route.Consumes, ct); err != nil {
			problems = append(problems, err)
		} else if consumer, found := route.Consumers[ct]; found {
			route.Consumer = consumer
			requestContentType = ct
		} else {
			problems = append(problems, errors.New(500, "no consumer registered for %s", ct))
		}
	}

	if len(problems) == 0 && runtime.HasBody(request) {
		if NegotiateContentType(request, route.Produces, requestContentType) == "" {
			problems = append(problems, errors.InvalidResponseFormat(request.Header.Get(runtime.HeaderAccept), route.Produces))
		}
	}

	if len(problems) > 0 {
		return errors.CompositeValidationError(problems...)
	}
	if binder == nil {
		return nil
	}
	// the binder is expected to validate while binding and report invalid requests
	return binder.BindRequest(request, route)
}
324
325// ContentType gets the parsed value of a content type
326// Returns the media type, its charset and a shallow copy of the request
327// when its context doesn't contain the content type value, otherwise it returns
328// the same request
329// Returns the error that runtime.ContentType may retunrs.
330func (c *Context) ContentType(request *http.Request) (string, string, *http.Request, error) {
331 var rCtx = request.Context()
332
333 if v, ok := rCtx.Value(ctxContentType).(*contentTypeValue); ok {
334 return v.MediaType, v.Charset, request, nil
335 }
336
337 mt, cs, err := runtime.ContentType(request.Header)
338 if err != nil {
339 return "", "", nil, err
340 }
341 rCtx = stdContext.WithValue(rCtx, ctxContentType, &contentTypeValue{mt, cs})
342 return mt, cs, request.WithContext(rCtx), nil
343}
344
// LookupRoute asks the router for a route matching the request's method and
// escaped URL path. It returns nil and false when there is no match.
func (c *Context) LookupRoute(request *http.Request) (*MatchedRoute, bool) {
	route, found := c.router.Lookup(request.Method, request.URL.EscapedPath())
	if !found {
		return nil, false
	}
	return route, true
}
352
353// RouteInfo tries to match a route for this request
354// Returns the matched route, a shallow copy of the request if its context
355// contains the matched router, otherwise the same request, and a bool to
356// indicate if it the request matches one of the routes, if it doesn't
357// then it returns false and nil for the other two return values
358func (c *Context) RouteInfo(request *http.Request) (*MatchedRoute, *http.Request, bool) {
359 var rCtx = request.Context()
360
361 if v, ok := rCtx.Value(ctxMatchedRoute).(*MatchedRoute); ok {
362 return v, request, ok
363 }
364
365 if route, ok := c.LookupRoute(request); ok {
366 rCtx = stdContext.WithValue(rCtx, ctxMatchedRoute, route)
367 return route, request.WithContext(rCtx), ok
368 }
369
370 return nil, nil, false
371}
372
373// ResponseFormat negotiates the response content type
374// Returns the response format and a shallow copy of the request if its context
375// doesn't contain the response format, otherwise the same request
376func (c *Context) ResponseFormat(r *http.Request, offers []string) (string, *http.Request) {
377 var rCtx = r.Context()
378
379 if v, ok := rCtx.Value(ctxResponseFormat).(string); ok {
380 debugLog("[%s %s] found response format %q in context", r.Method, r.URL.Path, v)
381 return v, r
382 }
383
384 format := NegotiateContentType(r, offers, "")
385 if format != "" {
386 debugLog("[%s %s] set response format %q in context", r.Method, r.URL.Path, format)
387 r = r.WithContext(stdContext.WithValue(rCtx, ctxResponseFormat, format))
388 }
389 debugLog("[%s %s] negotiated response format %q", r.Method, r.URL.Path, format)
390 return format, r
391}
392
// AllowedMethods returns the methods the router's OtherMethods reports for
// this request's method and escaped URL path.
func (c *Context) AllowedMethods(request *http.Request) []string {
	method, path := request.Method, request.URL.EscapedPath()
	return c.router.OtherMethods(method, path)
}
397
// ResetAuth returns a shallow copy of request whose context has both the
// security principal and the security scopes set to nil.
func (c *Context) ResetAuth(request *http.Request) *http.Request {
	cleared := stdContext.WithValue(request.Context(), ctxSecurityPrincipal, nil)
	cleared = stdContext.WithValue(cleared, ctxSecurityScopes, nil)
	return request.WithContext(cleared)
}
405
406// Authorize authorizes the request
407// Returns the principal object and a shallow copy of the request when its
408// context doesn't contain the principal, otherwise the same request or an error
409// (the last) if one of the authenticators returns one or an Unauthenticated error
410func (c *Context) Authorize(request *http.Request, route *MatchedRoute) (interface{}, *http.Request, error) {
411 if route == nil || !route.HasAuth() {
412 return nil, nil, nil
413 }
414
415 var rCtx = request.Context()
416 if v := rCtx.Value(ctxSecurityPrincipal); v != nil {
417 return v, request, nil
418 }
419
420 applies, usr, err := route.Authenticators.Authenticate(request, route)
421 if !applies || err != nil || !route.Authenticators.AllowsAnonymous() && usr == nil {
422 if err != nil {
423 return nil, nil, err
424 }
425 return nil, nil, errors.Unauthenticated("invalid credentials")
426 }
427 if route.Authorizer != nil {
428 if err := route.Authorizer.Authorize(request, usr); err != nil {
429 return nil, nil, errors.New(http.StatusForbidden, err.Error())
430 }
431 }
432
433 rCtx = stdContext.WithValue(rCtx, ctxSecurityPrincipal, usr)
434 rCtx = stdContext.WithValue(rCtx, ctxSecurityScopes, route.Authenticator.AllScopes())
435 return usr, request.WithContext(rCtx), nil
436}
437
438// BindAndValidate binds and validates the request
439// Returns the validation map and a shallow copy of the request when its context
440// doesn't contain the validation, otherwise it returns the same request or an
441// CompositeValidationError error
442func (c *Context) BindAndValidate(request *http.Request, matched *MatchedRoute) (interface{}, *http.Request, error) {
443 var rCtx = request.Context()
444
445 if v, ok := rCtx.Value(ctxBoundParams).(*validation); ok {
446 debugLog("got cached validation (valid: %t)", len(v.result) == 0)
447 if len(v.result) > 0 {
448 return v.bound, request, errors.CompositeValidationError(v.result...)
449 }
450 return v.bound, request, nil
451 }
452 result := validateRequest(c, request, matched)
453 rCtx = stdContext.WithValue(rCtx, ctxBoundParams, result)
454 request = request.WithContext(rCtx)
455 if len(result.result) > 0 {
456 return result.bound, request, errors.CompositeValidationError(result.result...)
457 }
458 debugLog("no validation errors found")
459 return result.bound, request, nil
460}
461
// NotFound is the default responder for requests that matched no route.
// It renders a "not found" error using the API's default media type.
func (c *Context) NotFound(rw http.ResponseWriter, r *http.Request) {
	produces := []string{c.api.DefaultProduces()}
	c.Respond(rw, r, produces, nil, errors.NotFound("not found"))
}
466
// Respond renders data as the response to r after content negotiation.
//
// produces lists the media types the operation can emit. The API's default
// media type is always offered last, so more specific producers win. The
// negotiated format is written to the Content-Type header, then:
//   - a Responder writes itself, using the route's producer for the format or
//     else the producer of the API's default media type;
//   - an error is rendered by ServeErrorFor (with the operation ID, or ""
//     when no operation is known), forcing a JSON Content-Type when no format
//     was negotiated;
//   - any other value is serialized with status 200 when no operation is
//     known, or with the operation's success status otherwise. HEAD requests
//     and 204 responses get no body. An operation without a success response
//     yields a 500 "can't produce response" error.
//
// A missing producer or a serialization failure causes a panic, leaving it
// to a recovery middleware. Producers are resolved before the status code is
// written, so such a middleware can still choose the status.
func (c *Context) Respond(rw http.ResponseWriter, r *http.Request, produces []string, route *MatchedRoute, data interface{}) {
	debugLog("responding to %s %s with produces: %v", r.Method, r.URL.Path, produces)
	defaultProduces := c.api.DefaultProduces()
	offers := make([]string, 0, len(produces)+1)
	for _, mt := range produces {
		if mt != defaultProduces {
			offers = append(offers, mt)
		}
	}
	// the default producer is last so more specific producers take precedence
	offers = append(offers, defaultProduces)
	debugLog("offers: %v", offers)

	var format string
	format, r = c.ResponseFormat(r, offers)
	rw.Header().Set(runtime.HeaderContentType, format)

	if resp, ok := data.(Responder); ok {
		resp.WriteResponse(rw, c.routeProducerOrDefault(route, format))
		return
	}

	if err, ok := data.(error); ok {
		if format == "" {
			rw.Header().Set(runtime.HeaderContentType, runtime.JSONMime)
		}
		if route == nil || route.Operation == nil {
			c.api.ServeErrorFor("")(rw, r, err)
			return
		}
		c.api.ServeErrorFor(route.Operation.ID)(rw, r, err)
		return
	}

	if route == nil || route.Operation == nil {
		if r.Method == http.MethodHead {
			rw.WriteHeader(http.StatusOK)
			return
		}
		producers := c.api.ProducersFor(normalizeOffers(offers))
		prod, ok := producers[format]
		if !ok {
			panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
		}
		rw.WriteHeader(http.StatusOK)
		if err := prod.Produce(rw, data); err != nil {
			panic(err) // let the recovery middleware deal with this
		}
		return
	}

	if _, code, ok := route.Operation.SuccessResponse(); ok {
		if code == http.StatusNoContent || r.Method == http.MethodHead {
			rw.WriteHeader(code)
			return
		}
		// resolve (and possibly panic) before committing the status code
		prod := c.routeProducerOrDefault(route, format)
		rw.WriteHeader(code)
		if err := prod.Produce(rw, data); err != nil {
			panic(err) // let the recovery middleware deal with this
		}
		return
	}

	c.api.ServeErrorFor(route.Operation.ID)(rw, r, errors.New(http.StatusInternalServerError, "can't produce response"))
}

// routeProducerOrDefault returns the route's producer for format, falling
// back to the producer registered for the API's default media type. A nil
// route contributes no producers. It panics when no producer is available.
func (c *Context) routeProducerOrDefault(route *MatchedRoute, format string) runtime.Producer {
	if route != nil {
		if prod, ok := route.Producers[format]; ok {
			return prod
		}
	}
	defaultProduces := c.api.DefaultProduces()
	prods := c.api.ProducersFor(normalizeOffers([]string{defaultProduces}))
	prod, ok := prods[defaultProduces]
	if !ok {
		panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
	}
	return prod
}
553
554// APIHandler returns a handler to serve the API, this includes a swagger spec, router and the contract defined in the swagger spec
555func (c *Context) APIHandler(builder Builder) http.Handler {
556 b := builder
557 if b == nil {
558 b = PassthroughBuilder
559 }
560
561 var title string
562 sp := c.spec.Spec()
563 if sp != nil && sp.Info != nil && sp.Info.Title != "" {
564 title = sp.Info.Title
565 }
566
567 redocOpts := RedocOpts{
568 BasePath: c.BasePath(),
569 Title: title,
570 }
571
572 return Spec("", c.spec.Raw(), Redoc(redocOpts, c.RoutesHandler(b)))
573}
574
// RoutesHandler returns a handler that serves only the API routes: the
// router in front of the operation executor, which builder may decorate.
// A nil builder means PassthroughBuilder.
func (c *Context) RoutesHandler(builder Builder) http.Handler {
	if builder == nil {
		builder = PassthroughBuilder
	}
	executor := NewOperationExecutor(c)
	return NewRouter(c, builder(executor))
}
582}