在 Go 单元测试中使用 httptest 包

1,214 阅读1分钟

「这是我参与2022首次更文挑战的第9天,活动详情查看:2022首次更文挑战」

Go 标准库中的 net/http/httptest 包可以用来测试 HTTP 接口,无需手动启动真实的服务。

httptest 方法介绍

NewRequest

NewRequest 方法用来创建一个供测试使用的 HTTP 请求(*http.Request),可以直接传给 handler。

方法说明:

func NewRequest(method, target string, body io.Reader) *http.Request
  • method 参数表示请求使用的 HTTP 方法,例如 http.MethodGet。
  • target 参数表示请求目标,可以是路径(可带查询串,如 /upper?word=abc),也可以是完整的 URL。
  • body 参数表示请求体,没有请求体时可以传 nil。

NewRecorder(响应体)

func NewRecorder() *ResponseRecorder

NewRecorder 方法用来创建一个响应记录器。返回的类型是 *httptest.ResponseRecorder,它实现了 http.ResponseWriter 接口,可以直接作为 handler 的 w 参数传入;handler 写入的状态码、响应头和响应体都会被记录下来,测试中可以通过 Code、Body 字段或 Result() 方法读取。

先看一个 HTTP GET 请求的例子,直接调用 handler,并用 ResponseRecorder 接收响应:

// TestUpperCaseHandler calls UpperCaseHandler directly, using an
// httptest.ResponseRecorder in place of a real connection (no server is
// started).
//
// Req: http://localhost:1234/upper?word=abc
// Res: ABC
func TestUpperCaseHandler(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/upper?word=abc", nil)
	w := httptest.NewRecorder()
	UpperCaseHandler(w, req)

	res := w.Result()
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		t.Errorf("expected status %d got %d", http.StatusOK, res.StatusCode)
	}
	data, err := ioutil.ReadAll(res.Body)
	if err != nil {
		// Without the body there is nothing meaningful left to check.
		t.Fatalf("expected error to be nil got %v", err)
	}
	if string(data) != "ABC" {
		t.Errorf("expected ABC got %v", string(data))
	}
}
// UpperCaseHandler responds with the upper-cased value of the "word" query
// parameter, e.g. GET /upper?word=abc -> 200 "ABC".
//
// It replies 400 "invalid request" when the raw query string cannot be
// parsed, and 400 "missing word" when the parameter is absent or empty.
func UpperCaseHandler(w http.ResponseWriter, r *http.Request) {
	query, err := url.ParseQuery(r.URL.RawQuery)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprint(w, "invalid request")
		return
	}
	word := query.Get("word")
	if word == "" {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprint(w, "missing word")
		return
	}
	w.WriteHeader(http.StatusOK)
	// Fprint, not Fprintf: word is client input and must never be used as a
	// format string (a "%" in it would otherwise be mangled).
	fmt.Fprint(w, strings.ToUpper(word))
}

运行结果如下:

=== RUN   TestUpperCaseHandler
--- PASS: TestUpperCaseHandler (0.00s)
PASS

再看一个 POST 请求的例子。这次用 httptest.NewServer 启动一个真实的本地测试服务器,再通过 HTTP 客户端向它发送请求:

// Business code: the handler under test.

// UpperCaseHandle1 reads a JSON object from the request body and responds
// with the upper-cased value of its "params" field, e.g.
// POST {"params":"abc"} -> 200 "ABC".
//
// It replies 400 "invalid request" when the body cannot be read or does not
// decode into a map of strings, and 400 "missing word" when "params" is
// absent or empty.
func UpperCaseHandle1(w http.ResponseWriter, r *http.Request) {
	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprint(w, "invalid request")
		return
	}
	// Demo-only: echo the raw request body to stdout.
	fmt.Println(string(body))

	mp := make(map[string]string)
	if err := json.Unmarshal(body, &mp); err != nil {
		// Report malformed JSON as such instead of silently falling
		// through to a misleading "missing word".
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprint(w, "invalid request")
		return
	}
	word := mp["params"]
	if word == "" {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprint(w, "missing word")
		return
	}
	w.WriteHeader(http.StatusOK)
	// Fprint, not Fprintf: word is client input, not a format string.
	fmt.Fprint(w, strings.ToUpper(word))
}
// Test_testApi starts a real local HTTP server with httptest.NewServer that
// serves UpperCaseHandle1, POSTs a JSON body to it, and checks the status
// code and response body.
func Test_testApi(t *testing.T) {
	tests := []struct {
		name       string
		params     map[string]string // JSON-encoded and sent as the request body
		wantStatus int
		wantBody   string
	}{
		{
			name:       "test api",
			params:     map[string]string{"params": "paramsBody"},
			wantStatus: http.StatusOK,
			wantBody:   "PARAMSBODY",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ts := httptest.NewServer(http.HandlerFunc(UpperCaseHandle1))
			defer ts.Close()

			paramsByte, err := json.Marshal(tt.params)
			if err != nil {
				t.Fatal(err)
			}
			// ts.Client() is the client configured to talk to this test server.
			resp, err := ts.Client().Post(ts.URL, "application/json", bytes.NewBuffer(paramsByte))
			if err != nil {
				// resp is nil on error; continuing would panic on resp.Body.
				t.Fatal(err)
			}
			defer resp.Body.Close()

			t.Log(resp.StatusCode)
			body, err := ioutil.ReadAll(resp.Body)
			if err != nil {
				t.Fatal(err)
			}
			if resp.StatusCode != tt.wantStatus {
				t.Errorf("expected status %d got %d: %s", tt.wantStatus, resp.StatusCode, body)
			}
			if string(body) != tt.wantBody {
				t.Errorf("expected %q got %q", tt.wantBody, body)
			}
		})
	}
}

执行结果:

=== RUN   Test_testApi
--- PASS: Test_testApi (208.72s)
=== RUN   Test_testApi/test_api
{"params":"paramsBody"}
    http_close_resp_test.go:107: 200
    --- PASS: Test_testApi/test_api (208.72s)
PASS

参考资料