grim/resticide

Correctly support query parameters
develop
2016-07-09, Gary Kramlich
92fb02aef02c
Correctly support query parameters
package test
import (
"bytes"
"crypto/md5"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"mime"
"mime/multipart"
"net/http"
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"gopkg.in/xeipuuv/gojsonschema.v0"
)
// Request represents a test request: what to send to the server and the
// response that is expected back.
//
// At most one body source is used. buildBody checks them in this order and
// uses the first one that is non-empty: Body, FormData, URLEncoded, Binary,
// JSON.
type Request struct {
// Path is the request path; it is resolved against the base URL.
Path string
// Query holds query parameters that are set on top of any already present
// in Path.
Query map[string]string
// Headers are added to the outgoing request after the computed
// Content-Type.
Headers map[string][]string
// Method is the HTTP method passed to http.NewRequest.
Method string
// Body is a literal request body, sent as-is.
Body string
// FormData describes the fields of a multipart/form-data body.
FormData map[string]FormData
// URLEncoded holds the fields of an application/x-www-form-urlencoded body.
URLEncoded map[string]string
// Binary is a file name, relative to the directory of the test file, whose
// contents are sent as the body.
Binary string
// JSON is marshaled with encoding/json and sent as application/json.
JSON interface{}
// Response is the expected response that compareResponse checks against.
Response Response
}
// buildBody assembles the request body and returns it together with the
// Content-Type that describes it.
//
// The first non-empty source wins, in this order: Body, FormData,
// URLEncoded, Binary, JSON.  When none is set the body is empty and the
// returned Content-Type is the empty string.  File names in FormData and
// Binary are resolved relative to the directory of test.Filename.
//
// The signature has no error return, so failures while building the body
// (unreadable file, value that cannot be marshaled) are best-effort: the
// affected part is left empty or skipped.
func (req *Request) buildBody(test *Test) (io.Reader, string) {
	body := &bytes.Buffer{}
	contentType := "application/html"

	// copyFile streams the named file into dst and closes the file before
	// returning, so no handle outlives the copy.
	copyFile := func(dst io.Writer, filename string) error {
		file, err := os.Open(filename)
		if err != nil {
			return err
		}
		defer file.Close()

		_, err = io.Copy(dst, file)
		return err
	}

	switch {
	case len(req.Body) > 0:
		body.WriteString(req.Body)

	case len(req.FormData) > 0:
		dirname := filepath.Dir(test.Filename)
		writer := multipart.NewWriter(body)

		for name, data := range req.FormData {
			switch {
			case len(data.Value) > 0:
				writer.WriteField(name, data.Value)

			case len(data.JSON) > 0:
				s, err := json.Marshal(data.JSON)
				if err != nil {
					// do something better here
					continue
				}
				writer.WriteField(name, string(s))

			case len(data.Filename) > 0:
				part, err := writer.CreateFormFile(name, filepath.Base(data.Filename))
				if err != nil {
					continue
				}

				// data.Value is always empty in this branch, so the file to
				// upload has to be located through data.Filename.
				// Best-effort: an unreadable file leaves the part empty.
				_ = copyFile(part, filepath.Join(dirname, data.Filename))
			}
		}

		contentType = writer.FormDataContentType()

		// Close writes the terminating boundary into body; it must happen
		// before the body is handed to the caller.
		writer.Close()

	case len(req.URLEncoded) > 0:
		values := url.Values{}
		for name, value := range req.URLEncoded {
			values.Set(name, value)
		}

		body.WriteString(values.Encode())
		contentType = "application/x-www-form-urlencoded"

	case len(req.Binary) > 0:
		dirname := filepath.Dir(test.Filename)
		filename := filepath.Join(dirname, req.Binary)

		// Derive the type from the file extension, falling back to a generic
		// binary type when the extension is unknown.
		contentType = mime.TypeByExtension(filepath.Ext(req.Binary))
		if len(contentType) == 0 {
			contentType = "application/octet-stream"
		}

		// Best-effort: an unreadable file results in an empty body.
		_ = copyFile(body, filename)

	case req.JSON != nil:
		// Best-effort: a value that cannot be marshaled yields an empty body.
		data, _ := json.Marshal(req.JSON)
		body.Write(data)
		contentType = "application/json"

	default:
		contentType = ""
	}

	return body, contentType
}
// buildRequest creates the *http.Request for this test request.
//
// req.Path is resolved against url (the base URL) and its path and query
// string replace those of the base.  Entries in req.Query are then set on top
// of that query string, overriding parameters of the same name.  The body and
// its Content-Type come from buildBody; entries from req.Headers are appended
// afterwards with Header.Add, so a Content-Type listed there is sent in
// addition to the computed one.
//
// An error is returned when req.Path cannot be parsed or when
// http.NewRequest rejects the method or URL.
func (req *Request) buildRequest(test *Test, url *url.URL) (*http.Request, error) {
	// work on a copy so the caller's base URL is left untouched
	testURL := *url

	// url is the base URL here (the parameter shadows the package), so this
	// resolves req.Path relative to it
	reqURL, err := url.Parse(req.Path)
	if err != nil {
		return nil, err
	}

	testURL.Path = reqURL.Path
	// RawPath carries the original percent-encoding (e.g. "%2F"); without it
	// the encoded form of the path is lost when the URL is serialized.
	testURL.RawPath = reqURL.RawPath
	testURL.RawQuery = reqURL.RawQuery

	// build the query string
	query := testURL.Query()
	for qname, qvalue := range req.Query {
		query.Set(qname, qvalue)
	}
	testURL.RawQuery = query.Encode()

	body, contentType := req.buildBody(test)

	hreq, err := http.NewRequest(req.Method, testURL.String(), body)
	if err != nil {
		return nil, err
	}

	// set the Content-Type we determined when building the body
	if len(contentType) > 0 {
		hreq.Header.Add("Content-Type", contentType)
	}

	for name, values := range req.Headers {
		for _, value := range values {
			hreq.Header.Add(name, value)
		}
	}

	return hreq, nil
}
// compareResponse checks hresp against the expectations in req.Response.
//
// The checks run in this order and stop at the first failure:
//
//  1. the status code must match;
//  2. every expected header that is present in the response must have
//     exactly the expected values;
//  3. if a File block is given, the body is checked against its "sha256" and
//     "md5" digests (when present) and written to its "filename" (when
//     present); otherwise, if a JSON schema is given, the body is validated
//     against it.
//
// It returns (true, nil) when everything matches and (false, err) with a
// description of the mismatch or of the I/O failure otherwise.  The response
// body is read in full but not closed here.
func (req *Request) compareResponse(hresp *http.Response) (bool, error) {
	// first check the resp code
	if hresp.StatusCode != req.Response.StatusCode {
		return false, fmt.Errorf(
			"Expected status code '%d' but got '%d'",
			req.Response.StatusCode,
			hresp.StatusCode,
		)
	}

	// check the headers
	for name, value := range req.Response.Headers {
		actual := hresp.Header[http.CanonicalHeaderKey(name)]

		// NOTE(review): an expected header that is absent from the response
		// is not treated as a mismatch - confirm this is intended.
		if actual == nil {
			continue
		}

		if !reflect.DeepEqual(actual, value) {
			return false, fmt.Errorf(
				"Expected header '%s' to be '%#v' but received '%#v'",
				name,
				value,
				actual,
			)
		}
	}

	// read the body into a byte array
	bodyByte, err := ioutil.ReadAll(hresp.Body)
	if err != nil {
		return false, err
	}

	// check if we have a file block and check it
	if len(req.Response.File) > 0 {
		if sha256Digest, ok := req.Response.File["sha256"]; ok {
			actual := fmt.Sprintf("%x", sha256.Sum256(bodyByte))
			if actual != sha256Digest {
				return false, fmt.Errorf(
					"Expected file to have a SHA256 of '%s' but received '%s'",
					sha256Digest,
					actual,
				)
			}
		}

		if md5Digest, ok := req.Response.File["md5"]; ok {
			actual := fmt.Sprintf("%x", md5.Sum(bodyByte))
			if actual != md5Digest {
				return false, fmt.Errorf(
					"Expected file to have a MD5 of '%s' but received '%s'",
					md5Digest,
					actual,
				)
			}
		}

		if filename, ok := req.Response.File["filename"]; ok {
			// WriteFile creates or truncates the file like os.Create does,
			// and also closes it, so no descriptor is leaked and errors
			// surfacing on close are reported.
			if err := ioutil.WriteFile(filename, bodyByte, 0666); err != nil {
				return false, err
			}
		}

		return true, nil
	}

	// check if we have a json schema to validate
	if len(req.Response.JSONSchema) != 0 {
		schemaLoader := gojsonschema.NewGoLoader(req.Response.JSONSchema)
		bodyLoader := gojsonschema.NewStringLoader(string(bodyByte))

		schemaResult, err := gojsonschema.Validate(schemaLoader, bodyLoader)
		if err != nil {
			return false, err
		}

		if !schemaResult.Valid() {
			// collect every violation so the failure lists all of them
			var messages []string
			for _, resultErr := range schemaResult.Errors() {
				messages = append(messages, resultErr.Description())
			}

			return false, fmt.Errorf(
				"Schema validation failed:\n%s",
				strings.Join(messages, "\n"),
			)
		}
	}

	return true, nil
}