GO代理服务器例程——字节流读取、多线程并发、并发安全、waitgroup | 青训营笔记

57 阅读5分钟

这是我参与「第五届青训营 」伴学笔记创作活动的第3天

一 GO代理服务器socks5例程

package main

import (
	"bufio"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
)

const socks5Ver = 0x05
const cmdBind = 0x01
const atypIPV4 = 0x01
const atypeHOST = 0x03
const atypeIPV6 = 0x04

func main() {
	server, err := net.Listen("tcp", "127.0.0.1:1080")
	if err != nil {
		panic(err)
	}
	for {
		client, err := server.Accept()
		if err != nil {
			log.Printf("Accept failed %v", err)
			continue
		}
		go process(client)
	}
}

func process(conn net.Conn) {
	defer conn.Close()
	reader := bufio.NewReader(conn) 
	err := auth(reader, conn)
	if err != nil {
		log.Printf("client %v auth failed:%v", conn.RemoteAddr(), err)
		return
	}
	err = connect(reader, conn)
	if err != nil {
		log.Printf("client %v auth failed:%v", conn.RemoteAddr(), err)
		return
	}
}

func auth(reader *bufio.Reader, conn net.Conn) (err error) {
	// +----+----------+----------+
	// |VER | NMETHODS | METHODS  |
	// +----+----------+----------+
	// | 1  |    1     | 1 to 255 |
	// +----+----------+----------+
	// VER: 协议版本,socks5为0x05
	// NMETHODS: 支持认证的方法数量
	// METHODS: 对应NMETHODS,NMETHODS的值为多少,METHODS就有多少个字节。RFC预定义了一些值的含义,内容如下:
	// X’00’ NO AUTHENTICATION REQUIRED
	// X’02’ USERNAME/PASSWORD

	ver, err := reader.ReadByte()
	if err != nil {
		return fmt.Errorf("read ver failed:%w", err)
	}
	if ver != socks5Ver {
		return fmt.Errorf("not supported ver:%v", ver)
	}
	methodSize, err := reader.ReadByte()
	if err != nil {
		return fmt.Errorf("read methodSize failed:%w", err)
	}
	method := make([]byte, methodSize)
	_, err = io.ReadFull(reader, method)
	if err != nil {
		return fmt.Errorf("read method failed:%w", err)
	}

	// +----+--------+
	// |VER | METHOD |
	// +----+--------+
	// | 1  |   1    |
	// +----+--------+
	_, err = conn.Write([]byte{socks5Ver, 0x00})
	if err != nil {
		return fmt.Errorf("write failed:%w", err)
	}
	return nil
}

func connect(reader *bufio.Reader, conn net.Conn) (err error) {
	// +----+-----+-------+------+----------+----------+
	// |VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
	// +----+-----+-------+------+----------+----------+
	// | 1  |  1  | X'00' |  1   | Variable |    2     |
	// +----+-----+-------+------+----------+----------+
	// VER 版本号,socks5的值为0x05
	// CMD 0x01表示CONNECT请求
	// RSV 保留字段,值为0x00
	// ATYP 目标地址类型,DST.ADDR的数据对应这个字段的类型。
	//   0x01表示IPv4地址,DST.ADDR为4个字节
	//   0x03表示域名,DST.ADDR是一个可变长度的域名
	// DST.ADDR 一个可变长度的值
	// DST.PORT 目标端口,固定2个字节

	buf := make([]byte, 4)
	_, err = io.ReadFull(reader, buf)
	if err != nil {
		return fmt.Errorf("read header failed:%w", err)
	}
	ver, cmd, atyp := buf[0], buf[1], buf[3]
	if ver != socks5Ver {
		return fmt.Errorf("not supported ver:%v", ver)
	}
	if cmd != cmdBind {
		return fmt.Errorf("not supported cmd:%v", ver)
	}
	addr := ""
	switch atyp {
	case atypIPV4:
		_, err = io.ReadFull(reader, buf)
		if err != nil {
			return fmt.Errorf("read atyp failed:%w", err)
		}
		addr = fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3])
	case atypeHOST:
		hostSize, err := reader.ReadByte()
		if err != nil {
			return fmt.Errorf("read hostSize failed:%w", err)
		}
		host := make([]byte, hostSize)
		_, err = io.ReadFull(reader, host)
		if err != nil {
			return fmt.Errorf("read host failed:%w", err)
		}
		addr = string(host)
	case atypeIPV6:
		return errors.New("IPv6: no supported yet")
	default:
		return errors.New("invalid atyp")
	}
	_, err = io.ReadFull(reader, buf[:2])
	if err != nil {
		return fmt.Errorf("read port failed:%w", err)
	}
	port := binary.BigEndian.Uint16(buf[:2])

	dest, err := net.Dial("tcp", fmt.Sprintf("%v:%v", addr, port))
	if err != nil {
		return fmt.Errorf("dial dst failed:%w", err)
	}
	defer dest.Close()
	log.Println("dial", addr, port)

	// +----+-----+-------+------+----------+----------+
	// |VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
	// +----+-----+-------+------+----------+----------+
	// | 1  |  1  | X'00' |  1   | Variable |    2     |
	// +----+-----+-------+------+----------+----------+
	// VER socks版本,这里为0x05
	// REP Relay field,内容取值如下 X’00’ succeeded
	// RSV 保留字段
	// ATYPE 地址类型
	// BND.ADDR 服务绑定的地址
	// BND.PORT 服务绑定的端口DST.PORT
	_, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
	if err != nil {
		return fmt.Errorf("write failed: %w", err)
	}
	ctx, cancel := context.WithCancel(context.Background()) //创建上下文,防止提前结束
	defer cancel()

	go func() {
		_, _ = io.Copy(dest, reader) //复制函数
		cancel() //子线程执行完成才执行这个,才能结束
	}()
	go func() {
		_, _ = io.Copy(conn, dest)
		cancel()//子线程执行完成才执行这个,才能结束
	}()

	<-ctx.Done() //检查cancel()完成,可结束
	return nil
}

上面代码的疑惑之读取字节流的方法:

1、
reader := bufio.NewReader(conn) //创建一个NewReader,reader是*bufio.Reader类型,此类有很多方法
reader.ReadByte() //取出1个字节,是弹出,再读取则会读取下一个。
2、
method := make([]byte, methodSize)
_, err = io.ReadFull(reader, method) //用methodSize个字节,装满method,装满就停。也可以用reader.ReadBytes(methodSize)
3、
_, err = io.ReadFull(reader, buf[:2]) //reader只剩最后两位了,取出来放在buf[0,1]位置
if err != nil { 
return fmt.Errorf("read port failed:%w", err) 
} 
port := binary.BigEndian.Uint16(buf[:2])

write和reader处理需要转换成流,字符串转成流方法如下:

4、
strings.NewReader("ABCDEFG")

二 GO多线程并发示例:

这里可以使用到标准库里面的一个context 机制,用 context 连 withcancel 来创建一个context。只要在最后等待ctx.Done(cancel 被调用, ctx.Done就会立刻返回。然后在上面的两个goroutinue 里面 调用一次 cancel 即可。

ctx, cancel := context.WithCancel(context.Background()) //创建上下文,防止提前结束
defer cancel()

go func() {
	_, _ = io.Copy(dest, reader) //复制函数
	cancel() //子线程执行完成才执行这个,才能结束
}()
go func() {
	_, _ = io.Copy(conn, dest)
	cancel()//子线程执行完成才执行这个,才能结束
}()

<-ctx.Done() //检查cancel完成,可结束
return nil

三 并发中的通道阻塞和等待:

package main

  


import (

    //"bytes"

    "fmt"

    "time"

)

  


func main() {

//var src [10]int

src := make(chan int)

dest := make(chan int,3)

go func() {

    defer close(src)

    for i := 0; i< 10; i++{

    src<- i+2

    //fmt.Println(src[i])

    //time.Sleep(time.Second)

    }

}()

//time.Sleep(time.Second*2)

go func() {

//time.Sleep(time.Second*1)//因为两个协程先后是未知的。src为数组则需要延时,不然这个协程启动时src里元素还没有放进去,会直接传全0。为chan则不需要,会阻塞等待src。

    defer close(dest)

    for i := range src {

//fmt.Println(i)

    dest <- i * i

    //time.Sleep(time.Second)

    }

}()

/*

go func() {

    for i := range dest {

    

    println(i)

    //time.Sleep(time.Second)

    }

}()*/

    //time.Sleep(time.Second*3)//src为数组则需要延时,不然直接主程序跑完就结束了。为chan则不需要,会阻塞等待。

    fmt.Println(src)

//fmt.Println(dest)

for i := range dest {

    fmt.Println(i)

    }

}

输出

4
9
16
25
36
49
64
81
100
121

注意:

  1. 必须有defer close 不然运行会出现关闭错误。
  2. //time.Sleep(time.Second*1)//因为两个协程先后是未知的。src为数组则需要延时,不然这个协程启动时src里元素还没有放进去,会直接传全0。为chan类型则不需要,非缓冲通道会阻塞等待src。
  3. //time.Sleep(time.Second*3)//src为数组则需要延时,不然直接主程序跑完就结束了。为chan则不需要,会阻塞等待。
  4. books.studygolang.com/gopl-zh/ch8… 详细的通道机制看这个文章。

四 并发锁

image.png

若已经被锁,则不会往下操作,需要等待解锁。然后检测解锁后,下一个线程才会加锁,然后操作,再解锁。 如果不加锁则会引发重复写冲突。

五 协程阻塞控制waitgroup

可以使用waitgroup来控制协程的阻塞,避免协程未完成就跳过。 add方法协程数+1;done方法可以完成一个协程,协程-1;wait是等待所有协程完成的阻塞,需要等待协程数量减至0才会通过。如下图所示 image.png