golang 实战 简易即时通讯服务(TCP)
首先,我们对整个服务进行简单划分: 通讯模组;消息储存模组;连接与连接之间的逻辑关系
通讯模组
通信流程
- 客户端->注册发信服务连接请求->服务端
- 客户端<-返回生成的客户端请求id+房间号<-服务端
- 客户端->注册收信服务连接请求->服务端
- 客户端<->并发收信发信<->服务端
- (客户端,服务端)任意一端程序停止,断开连接
报文组成(总报文=操作码+分报文)
- 总报文=[操作码 1 byte]+[正文]
- 注册收信报文=[客户端id 3 byte](服务端按"操作码+客户端id"共 4 byte 识别该报文;房间号沿用注册发信连接时登记的房间,无需再次携带)
- 注册发信报文=[房间号 2 byte]
- 消息响应报文=[消息id 1 byte]+[房间号 2 byte]+[消息来源客户端id 3 byte]+[正文]
- 消息通知报文=[消息数量 1 byte]
- 消息发送报文=[正文]
消息储存模组
这里我们为了方便,使用 sqlite3 作为数据库(驱动使用的是 github.com/mattn/go-sqlite3),当然也可以自行选择 MySQL、PostgreSQL、Oracle 或者你喜欢用的其他数据库。如果你使用的也是 sqlite,这里有个点需要注意:sqlite 同一时间只允许一个写入者(可以多个读取者),所以使用时需要加锁,否则多个 goroutine 并发执行非查询操作时会出现 database is locked 的错误。
数据储存层代码(本演示全部数据库操作都经过此层)
// writeTodatabase is the single entry point for every database operation in
// this demo. It opens data.db, runs query with args in the requested mode and
// returns the driver result wrapped in an any:
//
//   - QUERY    -> *sql.Rows   (taken under the shared read lock)
//   - QUERYROW -> *sql.Row    (taken under the shared read lock)
//   - EXEC     -> sql.Result  (taken under the exclusive write lock)
//
// On any failure, including an unsupported mod, the returned value is a plain
// nil and the error is non-nil, so callers can safely type-assert after
// checking err.
//
// NOTE(review): the lock is released and db.Close is called when this function
// returns, i.e. before the caller iterates the returned rows or scans the
// returned row. Opening the database once at start-up and sharing the handle
// would be the more conventional design.
func writeTodatabase(mod uint8, query string, args ...any) (any, error) {
	defer fmt.Printf("write to database finished,mod %d\n", mod)
	fmt.Println("start exec sql...")
	db, err := sql.Open("sqlite3", "data.db")
	if err != nil {
		return nil, err
	}
	defer db.Close()

	ctx := context.Background()
	// Queries take the read lock; inserts and updates take the write lock.
	switch mod {
	case QUERY:
		dbmutex.RLock()
		defer dbmutex.RUnlock()
		rows, err := db.QueryContext(ctx, query, args...)
		if err != nil {
			// Return an untyped nil rather than a (*sql.Rows)(nil) boxed in any.
			return nil, err
		}
		return rows, nil
	case QUERYROW:
		dbmutex.RLock()
		defer dbmutex.RUnlock()
		return db.QueryRowContext(ctx, query, args...), nil
	case EXEC:
		dbmutex.Lock()
		defer dbmutex.Unlock()
		res, err := db.ExecContext(ctx, query, args...)
		if err != nil {
			return nil, err
		}
		return res, nil
	default:
		// Previously this fell through to (nil, nil), which made callers panic
		// on their type assertion.
		return nil, fmt.Errorf("unsupported database mode %d", mod)
	}
}
业务逻辑关系
// config holds the runtime settings of the whole messaging service.
type config struct {
Port int `yaml:"Port"` // TCP port the server listens on
}
// client is the server-side state of one registered user.
type client struct {
roomid uint32 // id of the room this client belongs to
last_reply_time time.Time // set to time.Now() when the client registers
last_verion uint32 // newest message version (message id) this client is considered to have
mutex sync.RWMutex // taken when last_verion is updated by the receive loop
response_writer *response_Writer // receive-side connection; nil until the client registers one
}
// response_Writer is the receive side of a client: messages can only be pushed
// to a client after it has registered one.
type response_Writer struct {
con net.Conn // connection that server-to-client frames are written to
done chan struct{} // signalled by the send connection's cleanup to stop the receive loop
}
// room is a chat room: the version of its newest message plus the lock that
// readers and writers of that room's messages take.
type room struct {
latestversion uint32 // id of the newest message stored for this room
mutex sync.RWMutex // read-locked while fetching messages, write-locked while storing one
}
// message is one stored chat message.
type message struct {
id int // message id: auto-incremented and unique within its room table
roomid, sendid uint32 // room the message belongs to, and id of the client that sent it
content string // message body
}
room和client的关系 room实例本身就是一个房间最新消息版本号+读写锁。客户端读取信息只需要用room的锁上个读锁,获得最新消息后再解锁。写消息也一样(上写锁)。版本号由于向上增长,所以客户端版本号低于房间版本号就会拉取两个版本号之间的信息。 [此处图片]
上代码
前置任务
// Timeout and wire opcodes. The opcode is the first byte of every frame.
const (
TIMEOUTMAX = time.Duration(1000) * time.Millisecond // one second; not referenced by the code shown here
MSG_SYN = 10 // client: register a send/receive connection; server: "N messages follow" notice
MSG_ACK = 11 // client: send a message; server: first byte of each delivered message
MSG_CTR = 12 // control frames (server error replies; room switching is reserved)
)
// Error text constants.
const (
DATA_OUT_OF_RANGE = "data out of range" // not referenced by the code shown here
)
// Operation modes accepted by writeTodatabase.
const (
QUERY = 90 // multi-row query, yields *sql.Rows
QUERYROW = 91 // single-row query, yields *sql.Row
EXEC = 92 // statement without result rows (insert, create table, ...)
)
// Package-level state shared by every connection goroutine.
// NOTE(review): clientlist and roomlist are plain maps that are read and
// written from the per-connection goroutines without a lock.
var (
cnf config // runtime configuration; Port is filled from the command line in init
errorlog = log.New(os.Stderr, "[error]", log.Ltime|log.Llongfile) // Llongfile prints file:line, which pinpoints the failing call
debuglog = log.New(os.Stdout, "[debug]", log.Ltime|log.Llongfile)
clientlist = make(map[uint32]*client) // registered clients keyed by client id
roomlist = make(map[uint32]*room) // known rooms keyed by room id
dbmutex sync.RWMutex //database mutex
)
// init registers the -Port command-line flag (default 9000) as the target of
// cnf.Port and parses the command line before main runs.
// NOTE(review): flag.Parse inside init runs before flags registered later
// (for example by the testing package) exist; calling it at the top of main
// is the conventional place.
func init() {
flag.IntVar(&cnf.Port, "Port", 9000, "--cnf.Port xxx")
flag.Parse()
}
主函数
// main validates the configured port, starts listening on it and hands every
// accepted connection to readconse in its own goroutine. It exits with status
// 1 when the port is invalid or the listener cannot be created.
func main() {
	// A TCP port must fit in 16 bits; previously only negative values were rejected.
	if cnf.Port < 0 || cnf.Port > 65535 {
		errorlog.Println("Port", cnf.Port, " is invalid")
		os.Exit(1)
	}
	rand.Seed(time.Now().UnixNano()) //set rand seed
	listener, err := net.Listen("tcp", ":"+strconv.Itoa(cnf.Port))
	if err != nil {
		errorlog.Println(err.Error())
		// Report the failure to the caller instead of exiting with status 0.
		os.Exit(1)
	}
	for {
		con, err := listener.Accept()
		if err != nil {
			errorlog.Println(err.Error())
			// Back off briefly so a persistent accept error (for example running
			// out of file descriptors) does not turn into a busy loop.
			time.Sleep(50 * time.Millisecond)
			continue
		}
		go readconse(con)
	}
}
连接处理
// readconse serves one TCP connection until it fails or is closed.
//
// The first byte of every read is the opcode:
//
//   - MSG_SYN with 3 bytes in total registers a send connection: a fresh
//     client id is generated, the client is stored in clientlist (and its room
//     in roomlist) and [id 3 byte][room 2 byte] is written back.
//   - MSG_SYN with 4 bytes in total registers the receive connection of an
//     already registered client id; the connection is then used only to push
//     new room messages until it is told to stop or a write fails.
//   - MSG_ACK stores the rest of the frame as a message of the client's room.
//   - MSG_CTR is reserved (room switching and similar operations).
//
// When a send connection ends, its client is removed from clientlist and the
// client's receive loop, if any, is signalled to stop.
//
// NOTE(review): one Read is treated as exactly one frame, and clientlist and
// roomlist are accessed without a lock from every connection goroutine.
func readconse(con net.Conn) {
	defer con.Close()
	var (
		id        uint32
		cli       *client
		is_accept bool // true once this connection asked to be a receive connection
	)
	defer func() {
		if !is_accept {
			debuglog.Println("start deregister main", id, "process")
		} else {
			debuglog.Println("start deregister accept", id, "process")
		}
		if !is_accept && cli != nil {
			cli.mutex.RLock()
			writer := cli.response_writer
			cli.mutex.RUnlock()
			if writer != nil {
				debuglog.Println(id, "send end signal to accept process")
				// Non-blocking: the receive loop may already be gone, in which
				// case a plain send would block this goroutine forever.
				select {
				case writer.done <- struct{}{}:
				default:
				}
			}
		}
		if registered, ok := clientlist[id]; ok {
			if is_accept {
				debuglog.Println(id, "accept server deregistered")
			} else if registered == cli {
				// Only remove the entry this connection created. A connection
				// that never registered still has id == 0 and must not remove
				// somebody else's client.
				debuglog.Println(id, "send server deregistered")
				delete(clientlist, id)
			}
		}
	}()

	buffer := make([]byte, 1024)
	for {
		lang, err := con.Read(buffer)
		if err != nil {
			debuglog.Println("client", id, "closed connection")
			return
		}
		if lang == 0 {
			continue
		}
		switch buffer[0] {
		case MSG_SYN:
			if lang == 4 { // register the client's receive connection
				is_accept = true
				id = uint32(buffer[1])<<16 | uint32(buffer[2])<<8 | uint32(buffer[3])
				var ok bool
				if cli, ok = clientlist[id]; ok {
					serveAccept(cli, id, con)
					return
				}
			} else if lang == 3 { // register the client's send connection
				roomid := uint32(buffer[1])<<8 | uint32(buffer[2])
				id = newClientID()
				cli = &client{roomid: roomid, last_reply_time: time.Now()}
				// Register before replying so the client can open its receive
				// connection as soon as it learns the id.
				clientlist[id] = cli
				if _, ok := roomlist[roomid]; !ok {
					roomlist[roomid] = &room{}
					debuglog.Println("create new room", roomid)
				}
				resp := []byte{byte(id >> 16), byte(id >> 8), byte(id), buffer[1], buffer[2]}
				if _, err = con.Write(resp); err != nil {
					debuglog.Println("connection closed by client")
					return
				}
				debuglog.Printf("register client %v send server", id)
			}
		case MSG_ACK: // the client sends a message to its room
			if cli == nil {
				// Not registered yet: there is no room to store the message in.
				debuglog.Println("message from unregistered connection ignored")
				continue
			}
			if !writetomessage(buffer[1:lang], cli.roomid, id) {
				// The error reply travels over the receive connection, so a
				// failure here does not invalidate this connection.
				if err = cli.write(MSG_CTR, []byte("send message failed unknown error")); err != nil {
					errorlog.Println(err.Error())
				}
			}
		case MSG_CTR: // room switching and similar operations (not implemented)
		}
	}
}

// newClientID returns a random, currently unused client id in [1, 1<<24), so
// that it fits the 3-byte wire field and never equals the zero value used by
// connections that have not registered.
func newClientID() uint32 {
	for {
		id := 1 + rand.Uint32()%(1<<24-1)
		if _, taken := clientlist[id]; !taken {
			return id
		}
	}
}

// serveAccept turns con into the receive connection of cli and, once per
// second, pushes every room message newer than the client's version. It
// returns when the send connection signals done or a write to con fails.
func serveAccept(cli *client, id uint32, con net.Conn) {
	debuglog.Println("register client", id, "accept server")
	// Buffered so the stop signal is kept even while this loop is busy sending.
	writer := &response_Writer{con: con, done: make(chan struct{}, 1)}
	cli.mutex.Lock()
	cli.response_writer = writer
	cli.mutex.Unlock()

	tck := time.NewTicker(1 * time.Second) // version check interval; adjust as needed
	defer tck.Stop()
	for {
		select {
		case <-writer.done:
			debuglog.Println(id, "accept done signal")
			return
		case <-tck.C:
			rm, ok := roomlist[cli.roomid]
			if !ok {
				continue
			}
			cli.mutex.RLock()
			seen := cli.last_verion
			cli.mutex.RUnlock()
			rm.mutex.RLock()
			latest := rm.latestversion
			rm.mutex.RUnlock()
			if seen >= latest {
				continue
			}
			debuglog.Printf("client %v find new message", id)
			msgarr := getmessage(seen, cli.roomid)
			debuglog.Printf("client %v find %d new messages", id, len(msgarr))
			if len(msgarr) == 0 {
				continue
			}
			if err := sendtocli(&msgarr, con); err != nil {
				errorlog.Printf("cli accept server %v closed by client", id)
				return
			}
			// Advance to the newest message actually delivered, not to the
			// room's current version: a message stored after the fetch above
			// would otherwise be skipped.
			newest := seen
			for i := range msgarr {
				if v := uint32(msgarr[i].id); v > newest {
					newest = v
				}
			}
			cli.mutex.Lock()
			if newest > cli.last_verion {
				cli.last_verion = newest
			}
			cli.mutex.Unlock()
			debuglog.Printf("client %v latest version update finished", id)
		}
	}
}
获取最新消息发往客户端
// getmessage returns the messages of room roomid whose id is greater than
// version, oldest first. The room's read lock is held for the duration of the
// call. The result is never nil; it is empty when the room is unknown or the
// query fails, and holds only the rows read so far when a row cannot be
// decoded.
func getmessage(version, roomid uint32) []message {
	var msgarr = []message{}
	rm, ok := roomlist[roomid]
	if !ok {
		errorlog.Printf("read room %v msg failed,version %v, err %s", roomid, version, "room not registered")
		return msgarr
	}
	debuglog.Println("room list read locked")
	rm.mutex.RLock() // released automatically when the function returns
	defer rm.mutex.RUnlock()
	defer fmt.Println("room list read unlocked")

	// The table name is the decimal room id, so it cannot carry SQL syntax.
	table := strconv.FormatUint(uint64(roomid), 10)
	rwsany, err := writeTodatabase(QUERY, "select id,sendid,content from `"+table+"` where id>? order by id", version)
	if err != nil {
		errorlog.Printf("read room %v msg failed,version %v, err %s", roomid, version, err.Error())
		return msgarr
	}
	debuglog.Println("write to database finished")
	rws, ok := rwsany.(*sql.Rows)
	if !ok || rws == nil {
		errorlog.Printf("read room %v msg failed,version %v, err %s", roomid, version, "unexpected query result")
		return msgarr
	}
	// Release the underlying connection even when iteration stops early.
	defer rws.Close()

	for rws.Next() {
		debuglog.Println("read message...")
		msg := message{roomid: roomid}
		if err = rws.Scan(&msg.id, &msg.sendid, &msg.content); err != nil {
			debuglog.Printf("select id,sendid,content from `%v` where id>%v", table, version)
			errorlog.Printf("read room %v msg failed,version %v, err %s", roomid, version, err.Error())
			return msgarr
		}
		// Append only after a successful scan so no zeroed message is returned.
		msgarr = append(msgarr, msg)
	}
	if err = rws.Err(); err != nil {
		errorlog.Printf("read room %v msg failed,version %v, err %s", roomid, version, err.Error())
	}
	return msgarr
}
// sendtocli pushes the messages in src to con. Each batch starts with a
// notification frame [MSG_SYN][count] followed by count message frames. The
// count field is a single byte, so more than 255 messages are delivered as
// several batches instead of announcing a wrapped-around count. It returns the
// first write error, or nil when everything (or nothing) was sent.
func sendtocli(src *[]message, con net.Conn) error {
	if src == nil || len(*src) == 0 {
		return nil
	}
	const maxBatch = 255 // largest value the 1-byte count field can carry
	pending := *src
	for first := true; len(pending) > 0; first = false {
		n := len(pending)
		if n > maxBatch {
			n = maxBatch
		}
		if !first {
			// Keep the same pacing between the previous message and this header.
			time.Sleep(10 * time.Millisecond)
		}
		if _, err := con.Write([]byte{MSG_SYN, byte(n)}); err != nil {
			return err
		}
		debuglog.Println("prepare send", n)
		for i := 0; i < n; i++ {
			// Short pause between frames so the receiver is more likely to
			// read them as separate packets.
			time.Sleep(10 * time.Millisecond) //sleep 10 milliseconds
			frame := pending[i].tobytes()
			if _, err := con.Write(frame); err != nil {
				return err
			}
			debuglog.Println("send message", string(frame[7:]))
		}
		pending = pending[n:]
	}
	return nil
}
我们可能注意到发往客户端时有10毫秒的休眠间隔。TCP 是字节流协议,本身不保留"包"的边界:如果连续写入得太快,接收端还没来得及读缓冲区,几条独立的消息就可能被一次读出来,粘成一个(粘包)。这个间隔可以设置得小一些,但在本演示的协议下不能去掉;同时要注意,休眠只是降低了粘包的概率,并不能从根本上保证不粘包。
数据量大的时候还有另一个坑(拆包):假设接收端缓冲区是 1MB,而发送端每包 800KB、连发四包,那么接收端第一次读取时会读到第一个包加上第二个包的前一部分,第二个包就"身首异处"了——它的控制头被读到了第一次读取结果的末尾,之后再按包头解析第二个包时数据就是错的。更稳妥的做法是在每条报文前加上长度字段,由接收端按长度切分报文。
消息结构体转byte数组
// tobytes encodes the message as one wire frame:
//
//	opcode MSG_ACK [1] + id [1] + roomid [2] + sendid [3] + content
//
// Multi-byte fields are big-endian and every field keeps only its low-order
// bytes, so an id above 255 is truncated to its lowest byte.
func (m *message) tobytes() []byte {
	frame := make([]byte, 0, 7+len(m.content))
	frame = append(frame,
		MSG_ACK,
		byte(m.id),
		byte(m.roomid>>8), byte(m.roomid),
		byte(m.sendid>>16), byte(m.sendid>>8), byte(m.sendid),
	)
	return append(frame, m.content...)
}
// write sends the frame [code][v...] to the client over its receive
// connection. It fails with "not set response writer" when the client has not
// registered a receive connection, and otherwise returns the connection's
// write error, if any.
func (c *client) write(code uint8, v []byte) error {
	writer := c.response_writer
	if writer == nil {
		return fmt.Errorf("not set response writer")
	}
	frame := append([]byte{code}, v...)
	_, err := writer.con.Write(frame)
	return err
}
客户端储存消息至房间数据库,并更新消息版本号
func writetomessage(message []byte, roomid, sendid uint32) bool {
debuglog.Printf("client %v locked room %v", sendid, roomid)
roomlist[roomid].mutex.Lock()
defer roomlist[roomid].mutex.Unlock()
defer debuglog.Printf("client %v unlocked room %v", sendid, roomid)
_, err := writeTodatabase(EXEC, "insert into `"+strconv.FormatUint(uint64(roomid), 10)+"` (sendid,content)values(?,?)", sendid, string(message))
if err == nil {
debuglog.Printf("insert into `%v` (sendid,content)values(%v,'%v')", roomid, sendid, string(message))
rowany, _ := writeTodatabase(QUERYROW, "select MAX(id) from `"+strconv.FormatUint(uint64(roomid), 10)+"`")//我们这里上的写锁,所以最大消息id就是刚才我们插入的那条消息
var maxversion uint32
row := rowany.(*sql.Row)
if row.Scan(&maxversion) == nil {
roomlist[roomid].latestversion = maxversion
clientlist[sendid].last_verion = maxversion
} else {
errorlog.Println("get maxversion failed")
}
return true
} else {//数据表不存在,创建数据表
errorlog.Println("insert data error", err.Error())
_, err = writeTodatabase(EXEC, "create table `"+strconv.FormatUint(uint64(roomid), 10)+"` (id INTEGER PRIMARY KEY autoincrement NOT NULL,sendid INTEGER NOT NULL,content VARCHAR(1000))")
if err == nil {
_, err = writeTodatabase(EXEC, "insert into `"+strconv.FormatUint(uint64(roomid), 10)+"` (sendid,content)values(?,?)", sendid, string(message))
if err == nil {
rowany, _ := writeTodatabase(QUERYROW, "select MAX(id) from `"+strconv.FormatUint(uint64(roomid), 10)+"`")
var maxversion uint32
row := rowany.(*sql.Row)
if row.Scan(&maxversion) == nil {
roomlist[roomid].latestversion = maxversion
} else {
errorlog.Println("get maxversion failed")
}
return true
} else {
errorlog.Println(err.Error())
}
} else {
errorlog.Printf("insert data to room %v failed,err=%s", roomid, err.Error())
}
}
return false
}
实机演示
完整代码地址github.com/oswaldoooo/…