Go并发请求多个API

200 阅读3分钟

公司之前的很多项目都是用 PHP 写的。有些功能接口返回的数据量比较大，因此业务上会按数据模块拆分成多个内部接口，由主接口并发调用，最后整理结果集后统一返回。最近正好在学习 Go，就想看看这个并发请求封装类用 Go 应该怎么实现，顺便练练手。本着学习的目的发上来，有问题的地方欢迎大佬指正，一起学习。

请求库代码 utils/multi_requester.go：

package utils

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"sync"
	"time"
)

// MultiRequester queues JSON POST requests (Add), sends them concurrently
// (Exec), and hands back each decoded JSON response (GetContent).
// Construct it with NewMultiRequester; the maps must be initialized before use.
type MultiRequester struct {
	// client is the HTTP client used for every request.
	client             *http.Client
	// requestPool holds the requests queued by Add, in insertion order.
	requestPool        []*http.Request
	// requestBox maps each queued request to its "url" and the caller-supplied "params".
	requestBox         map[*http.Request]map[string]interface{}
	// requestID is sent as "request_id" in every request body.
	requestID          string
	// defaultInnerDomain is returned by getDefaultDomain; NewMultiRequester sets it
	// to "" and no visible code assigns another value yet.
	defaultInnerDomain string
	// mu guards responseCache (written by Exec goroutines, read by GetContent);
	// Cleanup also holds it while clearing requestBox.
	mu                 sync.Mutex
	// responseCache stores the decoded JSON body of each successfully executed request.
	responseCache      map[*http.Request]map[string]interface{}
}

// NewMultiRequester returns an empty MultiRequester whose HTTP client applies
// the given timeout to each request. A zero or negative timeout leaves the
// client without a timeout. The request ID starts as "0" and no default
// inner domain is configured.
func NewMultiRequester(timeout time.Duration) *MultiRequester {
	client := &http.Client{}
	if timeout > 0 {
		client.Timeout = timeout
	}

	return &MultiRequester{
		client:             client,
		requestPool:        make([]*http.Request, 0),
		requestBox:         make(map[*http.Request]map[string]interface{}),
		requestID:          "0",
		defaultInnerDomain: "",
		// Response cache, filled in by Exec.
		responseCache: make(map[*http.Request]map[string]interface{}),
	}
}

// getDefaultDomain returns the configured inner domain, which may be empty.
//
// TODO: when defaultInnerDomain is empty, load it here from configuration or
// a constant (for example "https://example.com"). No such lookup exists yet,
// so the stored value is returned unchanged.
func (mr *MultiRequester) getDefaultDomain() string {
	domain := mr.defaultInnerDomain
	return domain
}

// Add queues a JSON POST request to url and returns it. After Exec, pass the
// returned request to GetContent to read its response.
//
// The request body is the common parameters (inner_request, request_id)
// merged with params; keys in params override the common ones.
//
// If params cannot be JSON-encoded, or the request cannot be built (for
// example because url is malformed), the problem is logged, nothing is
// queued, and Add returns nil. GetContent(nil) then simply returns nil.
func (mr *MultiRequester) Add(url string, params map[string]interface{}) *http.Request {
	body := mr.getCommonParams()
	for k, v := range params {
		body[k] = v
	}

	jsonParams, err := json.Marshal(body)
	if err != nil {
		// Previously ignored: the request would have been sent with an empty body.
		log.Println("MultiRequester.Add: encode params for", url, "failed:", err)
		return nil
	}

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonParams))
	if err != nil {
		// Previously ignored: req was nil and the Header.Set below panicked.
		log.Println("MultiRequester.Add: build request for", url, "failed:", err)
		return nil
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", "your_user_agent")

	mr.requestPool = append(mr.requestPool, req)
	mr.requestBox[req] = map[string]interface{}{"url": url, "params": params}

	return req
}

// Exec sends every request queued by Add concurrently, using one goroutine
// per request, and blocks until all of them have finished.
//
// A successful response is decoded and stored so GetContent can return it.
// A failed request is logged together with its URL. Any response cached for
// that request by an earlier Exec is removed, so GetContent returns nil
// instead of stale data. Request bodies are rewound before sending, so
// requests that were already sent can be executed again.
func (mr *MultiRequester) Exec() {
	var wg sync.WaitGroup

	for _, req := range mr.requestPool {
		wg.Add(1)
		go func(req *http.Request) {
			defer wg.Done()

			var (
				resp map[string]interface{}
				err  error
			)
			// A previous Exec has drained req.Body. For in-memory bodies,
			// http.NewRequest sets GetBody, which returns a fresh reader over
			// the same bytes.
			if req.GetBody != nil {
				req.Body, err = req.GetBody()
			}
			if err == nil {
				resp, err = mr.doRequest(req)
			}

			mr.mu.Lock()
			if err != nil {
				delete(mr.responseCache, req)
			} else {
				mr.responseCache[req] = resp
			}
			mr.mu.Unlock()

			// Log outside the lock; include the URL so that, among many
			// concurrent requests, the failing one can be identified.
			if err != nil {
				log.Println("Error executing request:", req.URL, err)
			}
		}(req)
	}

	wg.Wait()
}

// doRequest sends req with the shared client and decodes the response body as
// a JSON object.
//
// It returns an error when:
//   - the request cannot be sent,
//   - the body cannot be read,
//   - the status is not 200 OK, or
//   - the body is not valid JSON for a map[string]interface{}.
//
// Errors are returned rather than logged, so the caller reports each failure
// exactly once.
func (mr *MultiRequester) doRequest(req *http.Request) (map[string]interface{}, error) {
	resp, err := mr.client.Do(req)
	if err != nil {
		// The *url.Error returned by Do already names the method and URL.
		return nil, err
	}
	defer resp.Body.Close()

	// Read the whole body before inspecting the status. Draining the body
	// lets the transport reuse the connection even for non-200 responses.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response body from %s: %w", req.URL, err)
	}

	if resp.StatusCode != http.StatusOK {
		// Include the actual status and URL; a bare "non-200" is not actionable.
		return nil, errors.New("non-200 HTTP status code from " + req.URL.String() + ": " + resp.Status)
	}

	var responseMap map[string]interface{}
	if err := json.Unmarshal(body, &responseMap); err != nil {
		return nil, fmt.Errorf("decode JSON response from %s: %w", req.URL, err)
	}
	return responseMap, nil
}

// getCommonParams builds the fields sent with every request: inner_request
// (always 1) and request_id (mr.requestID). A new map is allocated on every
// call, so callers may add to or overwrite its entries freely.
func (mr *MultiRequester) getCommonParams() map[string]interface{} {
	params := make(map[string]interface{}, 2)
	params["inner_request"] = 1
	params["request_id"] = mr.requestID
	return params
}

// GetContent returns the decoded JSON response that Exec stored for req. It
// returns nil when there is no cached response for req, for example because
// the request failed, has not been executed yet, or was never added.
func (mr *MultiRequester) GetContent(req *http.Request) map[string]interface{} {
	mr.mu.Lock()
	defer mr.mu.Unlock()

	// A missing key yields the zero value of the map's element type: a nil map.
	return mr.responseCache[req]
}

// Cleanup discards all queued requests, their metadata and their cached
// responses, so the MultiRequester can be reused for a new batch. The
// client, requestID and defaultInnerDomain are left untouched.
//
// After Cleanup, GetContent returns nil for previously added requests, and
// Exec no longer re-sends them.
func (mr *MultiRequester) Cleanup() {
	mr.mu.Lock()
	defer mr.mu.Unlock()

	// Replace the slice instead of re-slicing to [:0]. Re-slicing would keep
	// the old backing array, and with it every *http.Request, alive.
	mr.requestPool = make([]*http.Request, 0)
	for req := range mr.requestBox {
		delete(mr.requestBox, req)
	}
	for req := range mr.responseCache {
		delete(mr.responseCache, req)
	}
}

// You can add additional methods and functionality as needed

测试代码 utils/multi_requester_test.go（运行前需先启动下文的模拟 API 服务 api.go，它监听 localhost:8080）：

package utils

import (
	"fmt"
	"testing"
	"time"
)

// TestMultiRequester sends two requests to the mock API concurrently, prints
// both responses, and prints the total elapsed time.
//
// NOTE(review): this test needs the mock server (api.go) listening on
// localhost:8080, and it asserts nothing. The mock's endpoint2 sleeps 8s,
// which is longer than the 5s client timeout used here, so the second
// response is presumably expected to print as nil. Confirm this is intended.
func TestMultiRequester(t *testing.T) {
	start := time.Now()

	// Every request in this batch gets a 5-second client timeout.
	requester := NewMultiRequester(5 * time.Second)

	first := requester.Add("http://localhost:8080/endpoint1", map[string]interface{}{
		"key1": "中国",
		"key2": "美国",
	})
	second := requester.Add("http://localhost:8080/endpoint2", map[string]interface{}{
		"key3": "唐朝",
		"key4": "宋朝",
	})

	// Both requests run in parallel; Exec returns once both have finished.
	requester.Exec()

	fmt.Println("Response from endpoint 1:", requester.GetContent(first))
	fmt.Println("Response from endpoint 2:", requester.GetContent(second))

	requester.Cleanup()

	elapsed := time.Since(start)
	fmt.Printf("\ncost: %dms\n", elapsed.Milliseconds())
}

模拟 API 接口 api.go（endpoint1 延迟 1 秒返回，endpoint2 延迟 8 秒返回；后者超过了测试中设置的 5 秒客户端超时，因此请求 2 会因超时而失败）：

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

// JSON payloads exchanged with the mock endpoints.
type (
	// RequestBody1 is the JSON request body decoded by /endpoint1.
	RequestBody1 struct {
		Key1 string `json:"key1"`
		Key2 string `json:"key2"`
	}

	// RequestBody2 is the JSON request body decoded by /endpoint2.
	RequestBody2 struct {
		Key3 string `json:"key3"`
		Key4 string `json:"key4"`
	}

	// ResponseBody is the JSON response returned by both endpoints. Field
	// order determines the key order in the encoded JSON.
	ResponseBody struct {
		Message string `json:"message"`
		Name    string `json:"name"`
		City    string `json:"city"`
	}
)

// main starts the mock HTTP server on :8080, serving /endpoint1 and
// /endpoint2. If the server cannot start or stops serving, main panics
// instead of exiting silently with status 0.
func main() {
	// Register both mock endpoints on the default ServeMux.
	http.HandleFunc("/endpoint1", handleEndpoint1)
	http.HandleFunc("/endpoint2", handleEndpoint2)

	// Read timeouts protect the server from slow or stalled clients.
	// No WriteTimeout is set: handleEndpoint2 deliberately sleeps 8 seconds
	// before writing its response.
	srv := &http.Server{
		Addr:              ":8080",
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
	}

	fmt.Println("Server is listening on :8080...")
	// ListenAndServe always returns a non-nil error, e.g. when the port is
	// already in use. Previously that error was discarded.
	if err := srv.ListenAndServe(); err != nil {
		panic(err)
	}
}

// handleEndpoint1 serves POST /endpoint1. It decodes a RequestBody1, waits
// one second to simulate a slow backend, and echoes key1/key2 back as a JSON
// ResponseBody.
//
// Responses:
//   - 405 for any method other than POST;
//   - 400 when the body is not valid JSON for RequestBody1.
func handleEndpoint1(w http.ResponseWriter, r *http.Request) {
	// Only POST is accepted.
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var in RequestBody1
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// Simulated processing delay.
	time.Sleep(time.Second)

	out := ResponseBody{
		Message: fmt.Sprintf("Received request with key1=%s and key2=%s", in.Key1, in.Key2),
		Name:    in.Key1,
		City:    in.Key2,
	}

	w.Header().Set("Content-Type", "application/json")
	// NOTE(review): the Encode error is ignored; consider logging it.
	json.NewEncoder(w).Encode(out)
}

// handleEndpoint2 serves POST /endpoint2. It decodes a RequestBody2, waits
// eight seconds to simulate a very slow backend, and echoes key3/key4 back as
// a JSON ResponseBody.
//
// Responses:
//   - 405 for any method other than POST;
//   - 400 when the body is not valid JSON for RequestBody2.
func handleEndpoint2(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var payload RequestBody2
	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// Simulated processing delay.
	const delay = 8 * time.Second
	time.Sleep(delay)

	msg := fmt.Sprintf("Received request with key3=%s and key4=%s at endpoint2", payload.Key3, payload.Key4)
	reply := ResponseBody{Message: msg, Name: payload.Key3, City: payload.Key4}

	w.Header().Set("Content-Type", "application/json")
	// NOTE(review): the Encode error is ignored; consider logging it.
	json.NewEncoder(w).Encode(reply)
}