| // Copyright 2015 go-swagger maintainers |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package middleware |
| |
| import ( |
| stdContext "context" |
| "net/http" |
| "strings" |
| "sync" |
| |
| "github.com/go-openapi/analysis" |
| "github.com/go-openapi/errors" |
| "github.com/go-openapi/loads" |
| "github.com/go-openapi/runtime" |
| "github.com/go-openapi/runtime/logger" |
| "github.com/go-openapi/runtime/middleware/untyped" |
| "github.com/go-openapi/spec" |
| "github.com/go-openapi/strfmt" |
| ) |
| |
| // Debug when true turns on verbose logging |
| var Debug = logger.DebugEnabled() |
| var Logger logger.Logger = logger.StandardLogger{} |
| |
| func debugLog(format string, args ...interface{}) { |
| if Debug { |
| Logger.Printf(format, args...) |
| } |
| } |
| |
| // A Builder can create middlewares |
| type Builder func(http.Handler) http.Handler |
| |
| // PassthroughBuilder returns the handler, aka the builder identity function |
| func PassthroughBuilder(handler http.Handler) http.Handler { return handler } |
| |
| // RequestBinder is an interface for types to implement |
| // when they want to be able to bind from a request |
| type RequestBinder interface { |
| BindRequest(*http.Request, *MatchedRoute) error |
| } |
| |
| // Responder is an interface for types to implement |
| // when they want to be considered for writing HTTP responses |
| type Responder interface { |
| WriteResponse(http.ResponseWriter, runtime.Producer) |
| } |
| |
| // ResponderFunc wraps a func as a Responder interface |
| type ResponderFunc func(http.ResponseWriter, runtime.Producer) |
| |
| // WriteResponse writes to the response |
| func (fn ResponderFunc) WriteResponse(rw http.ResponseWriter, pr runtime.Producer) { |
| fn(rw, pr) |
| } |
| |
| // Context is a type safe wrapper around an untyped request context |
| // used throughout to store request context with the standard context attached |
| // to the http.Request |
| type Context struct { |
| spec *loads.Document |
| analyzer *analysis.Spec |
| api RoutableAPI |
| router Router |
| } |
| |
| type routableUntypedAPI struct { |
| api *untyped.API |
| hlock *sync.Mutex |
| handlers map[string]map[string]http.Handler |
| defaultConsumes string |
| defaultProduces string |
| } |
| |
// newRoutableUntypedAPI adapts an untyped.API into the routable form used by a Context.
//
// Every operation found by analyzing spec that also has a handler registered in
// api gets an http.Handler. That handler resolves the matched route, binds and
// validates the request, calls the operation handler and renders the result or
// error through context.Respond. Operations with security requirements are
// additionally wrapped with newSecureAPI.
//
// The handlers are indexed first by upper-cased HTTP method, then by path. The
// index stays nil when no operation has a registered handler.
//
// It returns nil when spec or api is nil.
func newRoutableUntypedAPI(spec *loads.Document, api *untyped.API, context *Context) *routableUntypedAPI {
	if spec == nil || api == nil {
		return nil
	}

	var handlers map[string]map[string]http.Handler
	analyzer := analysis.New(spec.Spec())

	for method, byPath := range analyzer.Operations() {
		upper := strings.ToUpper(method)
		for path, op := range byPath {
			requirements := analyzer.SecurityRequirementsFor(op)

			opHandler, registered := api.OperationHandlerFor(method, path)
			if !registered {
				continue
			}

			if handlers == nil {
				handlers = make(map[string]map[string]http.Handler)
			}
			if handlers[upper] == nil {
				handlers[upper] = make(map[string]http.Handler)
			}

			var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// reuse the route cached in the request context, or look it up
				route, withRoute, _ := context.RouteInfo(r)
				if withRoute != nil {
					r = withRoute
				}

				// bind and validate the request parameters
				params, validated, invalid := context.BindAndValidate(r, route)
				if invalid != nil {
					context.Respond(w, validated, route.Produces, route, invalid)
					return
				}

				// run the operation and render whichever outcome it produced
				result, err := opHandler.Handle(params)
				if err != nil {
					context.Respond(w, validated, route.Produces, route, err)
					return
				}
				context.Respond(w, validated, route.Produces, route, result)
			})

			if len(requirements) > 0 {
				handler = newSecureAPI(context, handler)
			}
			handlers[upper][path] = handler
		}
	}

	return &routableUntypedAPI{
		api:             api,
		hlock:           new(sync.Mutex),
		handlers:        handlers,
		defaultProduces: api.DefaultProduces,
		defaultConsumes: api.DefaultConsumes,
	}
}
| |
// HandlerFor returns the handler registered for method (matched case-insensitively)
// and path, and whether such a handler exists.
func (r *routableUntypedAPI) HandlerFor(method, path string) (http.Handler, bool) {
	r.hlock.Lock()
	defer r.hlock.Unlock()

	byPath, found := r.handlers[strings.ToUpper(method)]
	if !found {
		return nil, false
	}
	h, found := byPath[path]
	return h, found
}
// ServeErrorFor returns the untyped API's ServeError; the operation ID is
// ignored, so every operation shares the same error handler.
func (r *routableUntypedAPI) ServeErrorFor(_ string) func(http.ResponseWriter, *http.Request, error) {
	return r.api.ServeError
}

// ConsumersFor delegates to the untyped API's consumer lookup for the given media types.
func (r *routableUntypedAPI) ConsumersFor(mediaTypes []string) map[string]runtime.Consumer {
	return r.api.ConsumersFor(mediaTypes)
}

// ProducersFor delegates to the untyped API's producer lookup for the given media types.
func (r *routableUntypedAPI) ProducersFor(mediaTypes []string) map[string]runtime.Producer {
	return r.api.ProducersFor(mediaTypes)
}

// AuthenticatorsFor delegates to the untyped API's authenticator lookup for the given schemes.
func (r *routableUntypedAPI) AuthenticatorsFor(schemes map[string]spec.SecurityScheme) map[string]runtime.Authenticator {
	return r.api.AuthenticatorsFor(schemes)
}

// Authorizer returns the untyped API's authorizer.
func (r *routableUntypedAPI) Authorizer() runtime.Authorizer {
	return r.api.Authorizer()
}

// Formats returns the untyped API's strfmt registry.
func (r *routableUntypedAPI) Formats() strfmt.Registry {
	return r.api.Formats()
}

// DefaultProduces returns the default response media type captured at construction.
func (r *routableUntypedAPI) DefaultProduces() string {
	return r.defaultProduces
}

// DefaultConsumes returns the default request media type captured at construction.
func (r *routableUntypedAPI) DefaultConsumes() string {
	return r.defaultConsumes
}
| |
| // NewRoutableContext creates a new context for a routable API |
| func NewRoutableContext(spec *loads.Document, routableAPI RoutableAPI, routes Router) *Context { |
| var an *analysis.Spec |
| if spec != nil { |
| an = analysis.New(spec.Spec()) |
| } |
| ctx := &Context{spec: spec, api: routableAPI, analyzer: an, router: routes} |
| return ctx |
| } |
| |
| // NewContext creates a new context wrapper |
| func NewContext(spec *loads.Document, api *untyped.API, routes Router) *Context { |
| var an *analysis.Spec |
| if spec != nil { |
| an = analysis.New(spec.Spec()) |
| } |
| ctx := &Context{spec: spec, analyzer: an} |
| ctx.api = newRoutableUntypedAPI(spec, api, ctx) |
| ctx.router = routes |
| return ctx |
| } |
| |
| // Serve serves the specified spec with the specified api registrations as a http.Handler |
| func Serve(spec *loads.Document, api *untyped.API) http.Handler { |
| return ServeWithBuilder(spec, api, PassthroughBuilder) |
| } |
| |
// ServeWithBuilder serves the specified spec with the specified api registrations
// as a http.Handler whose routes are decorated by builder. The context is
// created without a router.
func ServeWithBuilder(spec *loads.Document, api *untyped.API, builder Builder) http.Handler {
	return NewContext(spec, api, nil).APIHandler(builder)
}
| |
// contextKey is the unexported key type for values this package stores in a
// request's context; being unexported, it cannot collide with other packages' keys.
type contextKey int8

// Request context keys. Numbering starts at 1 so the zero contextKey is never used.
const (
	// ctxContentType holds the parsed *contentTypeValue.
	ctxContentType contextKey = iota + 1
	// ctxResponseFormat holds the negotiated response media type (string).
	ctxResponseFormat
	// ctxMatchedRoute holds the *MatchedRoute for the request.
	ctxMatchedRoute
	// ctxBoundParams holds the cached binding/validation outcome.
	ctxBoundParams
	// ctxSecurityPrincipal holds the authenticated principal.
	ctxSecurityPrincipal
	// ctxSecurityScopes holds the principal's scopes ([]string).
	ctxSecurityScopes
)
| |
// MatchedRouteFrom returns the *MatchedRoute stored in the request context,
// or nil when none (or a value of another type) is stored there.
func MatchedRouteFrom(req *http.Request) *MatchedRoute {
	route, _ := req.Context().Value(ctxMatchedRoute).(*MatchedRoute)
	return route
}
| |
// SecurityPrincipalFrom returns the principal stored in the request context,
// or nil when there is none.
func SecurityPrincipalFrom(req *http.Request) interface{} {
	ctx := req.Context()
	return ctx.Value(ctxSecurityPrincipal)
}
| |
// SecurityScopesFrom returns the security scopes stored in the request context,
// or nil when none ([]string) are stored there.
func SecurityScopesFrom(req *http.Request) []string {
	scopes, _ := req.Context().Value(ctxSecurityScopes).([]string)
	return scopes
}
| |
// contentTypeValue caches a parsed Content-Type header: the media type and its charset.
type contentTypeValue struct {
	MediaType, Charset string
}
| |
| // BasePath returns the base path for this API |
| func (c *Context) BasePath() string { |
| return c.spec.BasePath() |
| } |
| |
| // RequiredProduces returns the accepted content types for responses |
| func (c *Context) RequiredProduces() []string { |
| return c.analyzer.RequiredProduces() |
| } |
| |
// BindValidRequest binds a params object to a request but only when the request is valid;
// if the request is not valid an error will be returned.
//
// For requests with a body, the Content-Type header is parsed, checked against
// route.Consumes and used to select route.Consumer from route.Consumers. The
// same media type then becomes the default offer when checking that an
// acceptable response format can be negotiated from route.Produces. Failures of
// these checks are reported together as a CompositeValidationError and the
// binder is not called. Otherwise the non-nil binder, which is expected to
// validate while it binds, decides the result.
func (c *Context) BindValidRequest(request *http.Request, route *MatchedRoute, binder RequestBinder) error {
	var failures []error
	defaultOffer := "*/*"

	if runtime.HasBody(request) {
		mediaType, _, err := runtime.ContentType(request.Header)
		if err != nil {
			failures = append(failures, err)
		} else if verr := validateContentType(route.Consumes, mediaType); verr != nil {
			failures = append(failures, verr)
		} else if consumer, found := route.Consumers[mediaType]; found {
			route.Consumer = consumer
			defaultOffer = mediaType
		} else {
			failures = append(failures, errors.New(500, "no consumer registered for %s", mediaType))
		}
	}

	if len(failures) == 0 && runtime.HasBody(request) {
		if NegotiateContentType(request, route.Produces, defaultOffer) == "" {
			failures = append(failures, errors.InvalidResponseFormat(request.Header.Get(runtime.HeaderAccept), route.Produces))
		}
	}

	if len(failures) > 0 {
		return errors.CompositeValidationError(failures...)
	}
	if binder == nil {
		return nil
	}
	return binder.BindRequest(request, route)
}
| |
| // ContentType gets the parsed value of a content type |
| // Returns the media type, its charset and a shallow copy of the request |
| // when its context doesn't contain the content type value, otherwise it returns |
| // the same request |
| // Returns the error that runtime.ContentType may retunrs. |
| func (c *Context) ContentType(request *http.Request) (string, string, *http.Request, error) { |
| var rCtx = request.Context() |
| |
| if v, ok := rCtx.Value(ctxContentType).(*contentTypeValue); ok { |
| return v.MediaType, v.Charset, request, nil |
| } |
| |
| mt, cs, err := runtime.ContentType(request.Header) |
| if err != nil { |
| return "", "", nil, err |
| } |
| rCtx = stdContext.WithValue(rCtx, ctxContentType, &contentTypeValue{mt, cs}) |
| return mt, cs, request.WithContext(rCtx), nil |
| } |
| |
// LookupRoute asks the router for a route matching the request's method and
// escaped path; it returns nil and false when none matches.
func (c *Context) LookupRoute(request *http.Request) (*MatchedRoute, bool) {
	route, found := c.router.Lookup(request.Method, request.URL.EscapedPath())
	if !found {
		return nil, false
	}
	return route, true
}
| |
| // RouteInfo tries to match a route for this request |
| // Returns the matched route, a shallow copy of the request if its context |
| // contains the matched router, otherwise the same request, and a bool to |
| // indicate if it the request matches one of the routes, if it doesn't |
| // then it returns false and nil for the other two return values |
| func (c *Context) RouteInfo(request *http.Request) (*MatchedRoute, *http.Request, bool) { |
| var rCtx = request.Context() |
| |
| if v, ok := rCtx.Value(ctxMatchedRoute).(*MatchedRoute); ok { |
| return v, request, ok |
| } |
| |
| if route, ok := c.LookupRoute(request); ok { |
| rCtx = stdContext.WithValue(rCtx, ctxMatchedRoute, route) |
| return route, request.WithContext(rCtx), ok |
| } |
| |
| return nil, nil, false |
| } |
| |
| // ResponseFormat negotiates the response content type |
| // Returns the response format and a shallow copy of the request if its context |
| // doesn't contain the response format, otherwise the same request |
| func (c *Context) ResponseFormat(r *http.Request, offers []string) (string, *http.Request) { |
| var rCtx = r.Context() |
| |
| if v, ok := rCtx.Value(ctxResponseFormat).(string); ok { |
| debugLog("[%s %s] found response format %q in context", r.Method, r.URL.Path, v) |
| return v, r |
| } |
| |
| format := NegotiateContentType(r, offers, "") |
| if format != "" { |
| debugLog("[%s %s] set response format %q in context", r.Method, r.URL.Path, format) |
| r = r.WithContext(stdContext.WithValue(rCtx, ctxResponseFormat, format)) |
| } |
| debugLog("[%s %s] negotiated response format %q", r.Method, r.URL.Path, format) |
| return format, r |
| } |
| |
// AllowedMethods gets the allowed methods for the path of this request, as
// reported by the router's OtherMethods for the request method and escaped path.
func (c *Context) AllowedMethods(request *http.Request) []string {
	method, path := request.Method, request.URL.EscapedPath()
	return c.router.OtherMethods(method, path)
}
| |
// ResetAuth returns a shallow copy of the request whose context has both the
// security principal and the security scopes overwritten with nil.
func (c *Context) ResetAuth(request *http.Request) *http.Request {
	ctx := request.Context()
	for _, key := range []contextKey{ctxSecurityPrincipal, ctxSecurityScopes} {
		ctx = stdContext.WithValue(ctx, key, nil)
	}
	return request.WithContext(ctx)
}
| |
| // Authorize authorizes the request |
| // Returns the principal object and a shallow copy of the request when its |
| // context doesn't contain the principal, otherwise the same request or an error |
| // (the last) if one of the authenticators returns one or an Unauthenticated error |
| func (c *Context) Authorize(request *http.Request, route *MatchedRoute) (interface{}, *http.Request, error) { |
| if route == nil || !route.HasAuth() { |
| return nil, nil, nil |
| } |
| |
| var rCtx = request.Context() |
| if v := rCtx.Value(ctxSecurityPrincipal); v != nil { |
| return v, request, nil |
| } |
| |
| applies, usr, err := route.Authenticators.Authenticate(request, route) |
| if !applies || err != nil || !route.Authenticators.AllowsAnonymous() && usr == nil { |
| if err != nil { |
| return nil, nil, err |
| } |
| return nil, nil, errors.Unauthenticated("invalid credentials") |
| } |
| if route.Authorizer != nil { |
| if err := route.Authorizer.Authorize(request, usr); err != nil { |
| return nil, nil, errors.New(http.StatusForbidden, err.Error()) |
| } |
| } |
| |
| rCtx = stdContext.WithValue(rCtx, ctxSecurityPrincipal, usr) |
| rCtx = stdContext.WithValue(rCtx, ctxSecurityScopes, route.Authenticator.AllScopes()) |
| return usr, request.WithContext(rCtx), nil |
| } |
| |
| // BindAndValidate binds and validates the request |
| // Returns the validation map and a shallow copy of the request when its context |
| // doesn't contain the validation, otherwise it returns the same request or an |
| // CompositeValidationError error |
| func (c *Context) BindAndValidate(request *http.Request, matched *MatchedRoute) (interface{}, *http.Request, error) { |
| var rCtx = request.Context() |
| |
| if v, ok := rCtx.Value(ctxBoundParams).(*validation); ok { |
| debugLog("got cached validation (valid: %t)", len(v.result) == 0) |
| if len(v.result) > 0 { |
| return v.bound, request, errors.CompositeValidationError(v.result...) |
| } |
| return v.bound, request, nil |
| } |
| result := validateRequest(c, request, matched) |
| rCtx = stdContext.WithValue(rCtx, ctxBoundParams, result) |
| request = request.WithContext(rCtx) |
| if len(result.result) > 0 { |
| return result.bound, request, errors.CompositeValidationError(result.result...) |
| } |
| debugLog("no validation errors found") |
| return result.bound, request, nil |
| } |
| |
// NotFound is the default responder used when no route has been matched yet:
// it renders a not-found error using only the API's default media type.
func (c *Context) NotFound(rw http.ResponseWriter, r *http.Request) {
	fallback := []string{c.api.DefaultProduces()}
	c.Respond(rw, r, fallback, nil, errors.NotFound("not found"))
}
| |
// Respond renders the response after doing some content negotiation.
//
// The response media type is negotiated from produces. The API default media
// type is always offered last so that more specific types take precedence, and
// the negotiated value is set as the Content-Type header. data is then written
// as follows:
//   - A Responder writes itself with the route's producer for the negotiated
//     format, falling back to the API default producer. A nil route goes
//     straight to that fallback.
//   - An error is handed to the API's ServeErrorFor handler for the route's
//     operation ID ("" when there is no operation). Content-Type is forced to
//     JSON when no format could be negotiated.
//   - Without an operation, any other value is produced with status 200.
//   - Otherwise the operation's success status code is written. Unless it is
//     204 or the request is a HEAD, data is produced with the route's producer
//     for the negotiated format, or the API default producer.
//
// Missing producers and production failures panic, leaving them to the
// recovery middleware.
func (c *Context) Respond(rw http.ResponseWriter, r *http.Request, produces []string, route *MatchedRoute, data interface{}) {
	debugLog("responding to %s %s with produces: %v", r.Method, r.URL.Path, produces)
	defaultProduces := c.api.DefaultProduces()

	offers := make([]string, 0, len(produces)+1)
	for _, mt := range produces {
		if mt != defaultProduces {
			offers = append(offers, mt)
		}
	}
	// the default producer is last so more specific producers take precedence
	offers = append(offers, defaultProduces)
	debugLog("offers: %v", offers)

	var format string
	format, r = c.ResponseFormat(r, offers)
	rw.Header().Set(runtime.HeaderContentType, format)

	// producerFor picks the producer registered for the negotiated format in
	// producers (which may be nil) and otherwise falls back to the API default
	// producer, panicking when even that one is missing.
	producerFor := func(producers map[string]runtime.Producer) runtime.Producer {
		if prod, ok := producers[format]; ok {
			return prod
		}
		prod, ok := c.api.ProducersFor(normalizeOffers([]string{defaultProduces}))[defaultProduces]
		if !ok {
			panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
		}
		return prod
	}

	if resp, ok := data.(Responder); ok {
		// guard against a nil route instead of dereferencing it
		var producers map[string]runtime.Producer
		if route != nil {
			producers = route.Producers
		}
		resp.WriteResponse(rw, producerFor(producers))
		return
	}

	if err, ok := data.(error); ok {
		if format == "" {
			rw.Header().Set(runtime.HeaderContentType, runtime.JSONMime)
		}
		operationID := ""
		if route != nil && route.Operation != nil {
			operationID = route.Operation.ID
		}
		c.api.ServeErrorFor(operationID)(rw, r, err)
		return
	}

	if route == nil || route.Operation == nil {
		rw.WriteHeader(http.StatusOK)
		if r.Method == http.MethodHead {
			return
		}
		prod, ok := c.api.ProducersFor(normalizeOffers(offers))[format]
		if !ok {
			panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
		}
		if err := prod.Produce(rw, data); err != nil {
			panic(err) // let the recovery middleware deal with this
		}
		return
	}

	if _, code, ok := route.Operation.SuccessResponse(); ok {
		rw.WriteHeader(code)
		if code == http.StatusNoContent || r.Method == http.MethodHead {
			return
		}
		if err := producerFor(route.Producers).Produce(rw, data); err != nil {
			panic(err) // let the recovery middleware deal with this
		}
		return
	}

	c.api.ServeErrorFor(route.Operation.ID)(rw, r, errors.New(http.StatusInternalServerError, "can't produce response"))
}
| |
| // APIHandler returns a handler to serve the API, this includes a swagger spec, router and the contract defined in the swagger spec |
| func (c *Context) APIHandler(builder Builder) http.Handler { |
| b := builder |
| if b == nil { |
| b = PassthroughBuilder |
| } |
| |
| var title string |
| sp := c.spec.Spec() |
| if sp != nil && sp.Info != nil && sp.Info.Title != "" { |
| title = sp.Info.Title |
| } |
| |
| redocOpts := RedocOpts{ |
| BasePath: c.BasePath(), |
| Title: title, |
| } |
| |
| return Spec("", c.spec.Raw(), Redoc(redocOpts, c.RoutesHandler(b))) |
| } |
| |
// RoutesHandler returns a handler to serve the API, just the routes and the contract defined in the swagger spec.
// The operation executor is wrapped by builder (the identity when nil) and
// placed behind the router.
func (c *Context) RoutesHandler(builder Builder) http.Handler {
	if builder == nil {
		builder = PassthroughBuilder
	}
	executor := NewOperationExecutor(c)
	return NewRouter(c, builder(executor))
}