blob: 8975b6e1c891bd75aae67488fe0e9459fa01332b [file] [log] [blame]
Serge Bazanskicc25bdf2018-10-25 14:02:58 +02001// Copyright 2015 go-swagger maintainers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package middleware
16
17import (
18 "encoding"
19 "encoding/base64"
20 "fmt"
21 "io"
22 "net/http"
23 "reflect"
24 "strconv"
25
26 "github.com/go-openapi/errors"
27 "github.com/go-openapi/runtime"
28 "github.com/go-openapi/spec"
29 "github.com/go-openapi/strfmt"
30 "github.com/go-openapi/swag"
31 "github.com/go-openapi/validate"
32)
33
// defaultMaxMemory is the maxMemory argument handed to
// http.Request.ParseMultipartForm when binding multipart form data (32 MiB).
const defaultMaxMemory = 32 << 20

// textUnmarshalType is the reflect.Type of the encoding.TextUnmarshaler
// interface; types whose pointer implements it are decoded via UnmarshalText.
var textUnmarshalType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
37
38func newUntypedParamBinder(param spec.Parameter, spec *spec.Swagger, formats strfmt.Registry) *untypedParamBinder {
39 binder := new(untypedParamBinder)
40 binder.Name = param.Name
41 binder.parameter = &param
42 binder.formats = formats
43 if param.In != "body" {
44 binder.validator = validate.NewParamValidator(&param, formats)
45 } else {
46 binder.validator = validate.NewSchemaValidator(param.Schema, spec, param.Name, formats)
47 }
48
49 return binder
50}
51
52type untypedParamBinder struct {
53 parameter *spec.Parameter
54 formats strfmt.Registry
55 Name string
56 validator validate.EntityValidator
57}
58
59func (p *untypedParamBinder) Type() reflect.Type {
60 return p.typeForSchema(p.parameter.Type, p.parameter.Format, p.parameter.Items)
61}
62
63func (p *untypedParamBinder) typeForSchema(tpe, format string, items *spec.Items) reflect.Type {
64 switch tpe {
65 case "boolean":
66 return reflect.TypeOf(true)
67
68 case "string":
69 if tt, ok := p.formats.GetType(format); ok {
70 return tt
71 }
72 return reflect.TypeOf("")
73
74 case "integer":
75 switch format {
76 case "int8":
77 return reflect.TypeOf(int8(0))
78 case "int16":
79 return reflect.TypeOf(int16(0))
80 case "int32":
81 return reflect.TypeOf(int32(0))
82 case "int64":
83 return reflect.TypeOf(int64(0))
84 default:
85 return reflect.TypeOf(int64(0))
86 }
87
88 case "number":
89 switch format {
90 case "float":
91 return reflect.TypeOf(float32(0))
92 case "double":
93 return reflect.TypeOf(float64(0))
94 }
95
96 case "array":
97 if items == nil {
98 return nil
99 }
100 itemsType := p.typeForSchema(items.Type, items.Format, items.Items)
101 if itemsType == nil {
102 return nil
103 }
104 return reflect.MakeSlice(reflect.SliceOf(itemsType), 0, 0).Type()
105
106 case "file":
107 return reflect.TypeOf(&runtime.File{}).Elem()
108
109 case "object":
110 return reflect.TypeOf(map[string]interface{}{})
111 }
112 return nil
113}
114
115func (p *untypedParamBinder) allowsMulti() bool {
116 return p.parameter.In == "query" || p.parameter.In == "formData"
117}
118
119func (p *untypedParamBinder) readValue(values runtime.Gettable, target reflect.Value) ([]string, bool, bool, error) {
120 name, in, cf, tpe := p.parameter.Name, p.parameter.In, p.parameter.CollectionFormat, p.parameter.Type
121 if tpe == "array" {
122 if cf == "multi" {
123 if !p.allowsMulti() {
124 return nil, false, false, errors.InvalidCollectionFormat(name, in, cf)
125 }
126 vv, hasKey, _ := values.GetOK(name)
127 return vv, false, hasKey, nil
128 }
129
130 v, hk, hv := values.GetOK(name)
131 if !hv {
132 return nil, false, hk, nil
133 }
134 d, c, e := p.readFormattedSliceFieldValue(v[len(v)-1], target)
135 return d, c, hk, e
136 }
137
138 vv, hk, _ := values.GetOK(name)
139 return vv, false, hk, nil
140}
141
142func (p *untypedParamBinder) Bind(request *http.Request, routeParams RouteParams, consumer runtime.Consumer, target reflect.Value) error {
143 // fmt.Println("binding", p.name, "as", p.Type())
144 switch p.parameter.In {
145 case "query":
146 data, custom, hasKey, err := p.readValue(runtime.Values(request.URL.Query()), target)
147 if err != nil {
148 return err
149 }
150 if custom {
151 return nil
152 }
153
154 return p.bindValue(data, hasKey, target)
155
156 case "header":
157 data, custom, hasKey, err := p.readValue(runtime.Values(request.Header), target)
158 if err != nil {
159 return err
160 }
161 if custom {
162 return nil
163 }
164 return p.bindValue(data, hasKey, target)
165
166 case "path":
167 data, custom, hasKey, err := p.readValue(routeParams, target)
168 if err != nil {
169 return err
170 }
171 if custom {
172 return nil
173 }
174 return p.bindValue(data, hasKey, target)
175
176 case "formData":
177 var err error
178 var mt string
179
180 mt, _, e := runtime.ContentType(request.Header)
181 if e != nil {
182 // because of the interface conversion go thinks the error is not nil
183 // so we first check for nil and then set the err var if it's not nil
184 err = e
185 }
186
187 if err != nil {
188 return errors.InvalidContentType("", []string{"multipart/form-data", "application/x-www-form-urlencoded"})
189 }
190
191 if mt != "multipart/form-data" && mt != "application/x-www-form-urlencoded" {
192 return errors.InvalidContentType(mt, []string{"multipart/form-data", "application/x-www-form-urlencoded"})
193 }
194
195 if mt == "multipart/form-data" {
196 if err = request.ParseMultipartForm(defaultMaxMemory); err != nil {
197 return errors.NewParseError(p.Name, p.parameter.In, "", err)
198 }
199 }
200
201 if err = request.ParseForm(); err != nil {
202 return errors.NewParseError(p.Name, p.parameter.In, "", err)
203 }
204
205 if p.parameter.Type == "file" {
206 file, header, ffErr := request.FormFile(p.parameter.Name)
207 if ffErr != nil {
208 return errors.NewParseError(p.Name, p.parameter.In, "", ffErr)
209 }
210 target.Set(reflect.ValueOf(runtime.File{Data: file, Header: header}))
211 return nil
212 }
213
214 if request.MultipartForm != nil {
215 data, custom, hasKey, rvErr := p.readValue(runtime.Values(request.MultipartForm.Value), target)
216 if rvErr != nil {
217 return rvErr
218 }
219 if custom {
220 return nil
221 }
222 return p.bindValue(data, hasKey, target)
223 }
224 data, custom, hasKey, err := p.readValue(runtime.Values(request.PostForm), target)
225 if err != nil {
226 return err
227 }
228 if custom {
229 return nil
230 }
231 return p.bindValue(data, hasKey, target)
232
233 case "body":
234 newValue := reflect.New(target.Type())
235 if !runtime.HasBody(request) {
236 if p.parameter.Default != nil {
237 target.Set(reflect.ValueOf(p.parameter.Default))
238 }
239
240 return nil
241 }
242 if err := consumer.Consume(request.Body, newValue.Interface()); err != nil {
243 if err == io.EOF && p.parameter.Default != nil {
244 target.Set(reflect.ValueOf(p.parameter.Default))
245 return nil
246 }
247 tpe := p.parameter.Type
248 if p.parameter.Format != "" {
249 tpe = p.parameter.Format
250 }
251 return errors.InvalidType(p.Name, p.parameter.In, tpe, nil)
252 }
253 target.Set(reflect.Indirect(newValue))
254 return nil
255 default:
256 return errors.New(500, fmt.Sprintf("invalid parameter location %q", p.parameter.In))
257 }
258}
259
260func (p *untypedParamBinder) bindValue(data []string, hasKey bool, target reflect.Value) error {
261 if p.parameter.Type == "array" {
262 return p.setSliceFieldValue(target, p.parameter.Default, data, hasKey)
263 }
264 var d string
265 if len(data) > 0 {
266 d = data[len(data)-1]
267 }
268 return p.setFieldValue(target, p.parameter.Default, d, hasKey)
269}
270
271func (p *untypedParamBinder) setFieldValue(target reflect.Value, defaultValue interface{}, data string, hasKey bool) error {
272 tpe := p.parameter.Type
273 if p.parameter.Format != "" {
274 tpe = p.parameter.Format
275 }
276
277 if (!hasKey || (!p.parameter.AllowEmptyValue && data == "")) && p.parameter.Required && p.parameter.Default == nil {
278 return errors.Required(p.Name, p.parameter.In)
279 }
280
281 ok, err := p.tryUnmarshaler(target, defaultValue, data)
282 if err != nil {
283 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
284 }
285 if ok {
286 return nil
287 }
288
289 defVal := reflect.Zero(target.Type())
290 if defaultValue != nil {
291 defVal = reflect.ValueOf(defaultValue)
292 }
293
294 if tpe == "byte" {
295 if data == "" {
296 if target.CanSet() {
297 target.SetBytes(defVal.Bytes())
298 }
299 return nil
300 }
301
302 b, err := base64.StdEncoding.DecodeString(data)
303 if err != nil {
304 b, err = base64.URLEncoding.DecodeString(data)
305 if err != nil {
306 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
307 }
308 }
309 if target.CanSet() {
310 target.SetBytes(b)
311 }
312 return nil
313 }
314
315 switch target.Kind() {
316 case reflect.Bool:
317 if data == "" {
318 if target.CanSet() {
319 target.SetBool(defVal.Bool())
320 }
321 return nil
322 }
323 b, err := swag.ConvertBool(data)
324 if err != nil {
325 return err
326 }
327 if target.CanSet() {
328 target.SetBool(b)
329 }
330 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
331 if data == "" {
332 if target.CanSet() {
333 rd := defVal.Convert(reflect.TypeOf(int64(0)))
334 target.SetInt(rd.Int())
335 }
336 return nil
337 }
338 i, err := strconv.ParseInt(data, 10, 64)
339 if err != nil {
340 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
341 }
342 if target.OverflowInt(i) {
343 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
344 }
345 if target.CanSet() {
346 target.SetInt(i)
347 }
348
349 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
350 if data == "" {
351 if target.CanSet() {
352 rd := defVal.Convert(reflect.TypeOf(uint64(0)))
353 target.SetUint(rd.Uint())
354 }
355 return nil
356 }
357 u, err := strconv.ParseUint(data, 10, 64)
358 if err != nil {
359 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
360 }
361 if target.OverflowUint(u) {
362 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
363 }
364 if target.CanSet() {
365 target.SetUint(u)
366 }
367
368 case reflect.Float32, reflect.Float64:
369 if data == "" {
370 if target.CanSet() {
371 rd := defVal.Convert(reflect.TypeOf(float64(0)))
372 target.SetFloat(rd.Float())
373 }
374 return nil
375 }
376 f, err := strconv.ParseFloat(data, 64)
377 if err != nil {
378 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
379 }
380 if target.OverflowFloat(f) {
381 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
382 }
383 if target.CanSet() {
384 target.SetFloat(f)
385 }
386
387 case reflect.String:
388 value := data
389 if value == "" {
390 value = defVal.String()
391 }
392 // validate string
393 if target.CanSet() {
394 target.SetString(value)
395 }
396
397 case reflect.Ptr:
398 if data == "" && defVal.Kind() == reflect.Ptr {
399 if target.CanSet() {
400 target.Set(defVal)
401 }
402 return nil
403 }
404 newVal := reflect.New(target.Type().Elem())
405 if err := p.setFieldValue(reflect.Indirect(newVal), defVal, data, hasKey); err != nil {
406 return err
407 }
408 if target.CanSet() {
409 target.Set(newVal)
410 }
411
412 default:
413 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
414 }
415 return nil
416}
417
418func (p *untypedParamBinder) tryUnmarshaler(target reflect.Value, defaultValue interface{}, data string) (bool, error) {
419 if !target.CanSet() {
420 return false, nil
421 }
422 // When a type implements encoding.TextUnmarshaler we'll use that instead of reflecting some more
423 if reflect.PtrTo(target.Type()).Implements(textUnmarshalType) {
424 if defaultValue != nil && len(data) == 0 {
425 target.Set(reflect.ValueOf(defaultValue))
426 return true, nil
427 }
428 value := reflect.New(target.Type())
429 if err := value.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(data)); err != nil {
430 return true, err
431 }
432 target.Set(reflect.Indirect(value))
433 return true, nil
434 }
435 return false, nil
436}
437
438func (p *untypedParamBinder) readFormattedSliceFieldValue(data string, target reflect.Value) ([]string, bool, error) {
439 ok, err := p.tryUnmarshaler(target, p.parameter.Default, data)
440 if err != nil {
441 return nil, true, err
442 }
443 if ok {
444 return nil, true, nil
445 }
446
447 return swag.SplitByFormat(data, p.parameter.CollectionFormat), false, nil
448}
449
450func (p *untypedParamBinder) setSliceFieldValue(target reflect.Value, defaultValue interface{}, data []string, hasKey bool) error {
451 sz := len(data)
452 if (!hasKey || (!p.parameter.AllowEmptyValue && (sz == 0 || (sz == 1 && data[0] == "")))) && p.parameter.Required && defaultValue == nil {
453 return errors.Required(p.Name, p.parameter.In)
454 }
455
456 defVal := reflect.Zero(target.Type())
457 if defaultValue != nil {
458 defVal = reflect.ValueOf(defaultValue)
459 }
460
461 if !target.CanSet() {
462 return nil
463 }
464 if sz == 0 {
465 target.Set(defVal)
466 return nil
467 }
468
469 value := reflect.MakeSlice(reflect.SliceOf(target.Type().Elem()), sz, sz)
470
471 for i := 0; i < sz; i++ {
472 if err := p.setFieldValue(value.Index(i), nil, data[i], hasKey); err != nil {
473 return err
474 }
475 }
476
477 target.Set(value)
478
479 return nil
480}