3.1 实现 Server v0.2

package main

import (
    "fmt"
    "net"
    "sync"

    "github.com/bwmarrin/snowflake"
    "github.com/huoyijie/GoChat/lib"
)

// 封装客户端连接,增加 snowflake.ID
type socket struct {
    id   snowflake.ID
    conn net.Conn
}

// 存储当前所有客户端连接
var sockets = make(map[snowflake.ID]*socket)

// 多个协程并发读写 sockets 时,需要使用读写锁
var lock sync.RWMutex

// 写锁
func wSockets(wSockets func()) {
    lock.Lock()
    defer lock.Unlock()
    wSockets()
}

// 读锁
func rSockets(rSockets func()) {
    lock.RLock()
    defer lock.RUnlock()
    rSockets()
}

func main() {
    // 启动单独协程,监听 ctrl+c 或 kill 信号,收到信号结束进程
    go lib.SignalHandler()

    // tcp 监听地址 0.0.0.0:8888
    addr := ":8888"

    // tcp 监听
    ln, err := net.Listen("tcp", addr)

    // tcp 监听遇到错误退出进程
    lib.FatalNotNil(err)
    // 输出日志
    lib.LogMessage("Listening on", addr)

    // 创建 snowflake Node
    node, err := snowflake.NewNode(1)
    lib.FatalNotNil(err)

    // 循环接受客户端连接
    for {
        // 每当有客户端连接时,ln.Accept 会返回新的连接 conn
        conn, err := ln.Accept()
        // 如果接受的新连接遇到错误,则退出进程
        lib.FatalNotNil(err)

        // 生成新 ID
        id := node.Generate()
        // 保存新连接
        wSockets(func() {
            sockets[id] = &socket{id, conn}
        })

        // 为每个客户端启动一个协程,读取客户端发送的消息并转发
        go lib.HandleConnection(
            conn,
            id,
            func(msg string) {
                rSockets(func() {
                    for k, v := range sockets {
                        // 向其他所有客户端(除了自己)转发消息
                        if k != id {
                            fmt.Fprintf(v.conn, "%d:%s\r\n", id, msg)
                        }
                    }
                })
            },
            func() {
                // 从当前方法返回时,关闭连接
                conn.Close()
                // 删除连接
                wSockets(func() {
                    delete(sockets, id)
                })
            })
    }
}

server 要把某个 client 发送的消息转发给其他所有 client,需要存储所有的连接,同时给所有连接分配一个 ID,以区分不同的 client。

// 封装客户端连接,增加 snowflake.ID
type socket struct {
    id   snowflake.ID
    conn net.Conn
}

ID 采用 Twitter snowflake 算法生成,需要导入 bwmarrin/snowflake 库

import (
    "fmt"
    "net"
    "sync"

+    "github.com/bwmarrin/snowflake"
    "github.com/huoyijie/GoChat/lib"
)

导入外部库后,需要进入项目根目录 GoChat,然后执行 go mod tidy,会自动下载 bwmarrin/snowflake 库。

// 存储当前所有客户端连接
var sockets = make(map[snowflake.ID]*socket)

// 多个协程并发读写 sockets 时,需要使用读写锁
var lock sync.RWMutex

所有封装好的连接,放入 map 类型变量 sockets 里。当有新连接建立时写入 map,有连接断开时从 map 中删除,转发消息时需要遍历读取所有客户端连接。这些对 sockets 变量的读写操作发生在不同的协程里面,为了避免多个协程因并发读写 sockets 共享内存而造成错误,需要读写锁进行同步。

// 写锁
func wSockets(wSockets func()) {
    lock.Lock()
    defer lock.Unlock()
    wSockets()
}

// 读锁
func rSockets(rSockets func()) {
    lock.RLock()
    defer lock.RUnlock()
    rSockets()
}

上面分别是写锁定和读锁定方法,写锁定后,禁止其他所有对 sockets 变量的读写操作。读锁定后,禁止其他所有对 sockets 变量的写操作,但允许其他所有读操作。锁定后要进行的操作,通过 func 类型参数传入。

1.新连接建立后要保存

// 通过 snowflake 算法生成新 ID
id := node.Generate()
// 保存新连接
wSockets(func() {
    sockets[id] = &socket{id, conn}
})

2.断开连接后要删除

// 删除连接
wSockets(func() {
    delete(sockets, id)
})

3.转发客户端消息

rSockets(func() {
    for k, v := range sockets {
        // 向其他所有客户端(除了自己)转发消息
        if k != id {
            fmt.Fprintf(v.conn, "%d:%s\r\n", id, msg)
        }
    }
})

lib.HandleConnection 函数现在需要传入 2 个 func 类型参数,分别是 handleMsg 和 close。前者是消息处理回调函数,后者是连接关闭回调函数。