hub
package hub
import (
"fmt"
"github.com/gorilla/websocket"
"github.com/zeromicro/go-zero/core/logx"
"sync"
"time"
)
// heartbeatInterval is how often PingTimer sends a websocket ping to every
// registered client.
var heartbeatInterval = time.Second * 25
// Hub tracks the connected websocket clients and owns the channels used to
// register, unregister and message them. Run, WriteMessage and PingTimer
// (all started by Start) operate on it concurrently.
type Hub struct {
	// Connection registry.
	// -----user-----
	// Registered clients keyed by ClientUser.ID. The accessor methods
	// (ClientNumber, IsAlive, DelClient, GetClientList) guard it with
	// Client2ConnMu.
	// NOTE(review): every access from any goroutine must hold Client2ConnMu.
	Clint2ConnMap map[string]*ClientUser
	// Messages that Run writes to every registered client (fan-out).
	// NOTE(review): no sender on this channel is visible in this file.
	broadcast_user chan SendMsg
	// Register requests (sent by ServerWS, consumed by Run).
	registe_user chan *ClientUser
	// Unregister requests (sent by readPump on exit, consumed by Run).
	unregister_user chan *ClientUser
	// Messages addressed to one client by ClientId (queued by
	// SendMessage2Client, consumed by WriteMessage).
	ToClientChan chan SendMsg
	// Guards Clint2ConnMap.
	Client2ConnMu sync.RWMutex
}
// Start builds a hub and launches its three background loops: the event loop
// (Run), the per-client writer (WriteMessage) and the heartbeat (PingTimer).
// The running hub is returned to the caller.
func Start() *Hub {
	h := NewHub()
	for _, loop := range []func(){h.Run, h.WriteMessage, h.PingTimer} {
		go loop()
	}
	return h
}
// NewHub allocates a Hub with an empty client registry and all channels
// created: the broadcast and per-client queues hold up to 100 messages, and
// the register/unregister channels are unbuffered. The hub does nothing until
// its loops are started (see Start).
func NewHub() *Hub {
	h := &Hub{}
	h.Clint2ConnMap = make(map[string]*ClientUser)
	h.broadcast_user = make(chan SendMsg, 100)
	h.ToClientChan = make(chan SendMsg, 100)
	h.registe_user = make(chan *ClientUser)
	h.unregister_user = make(chan *ClientUser)
	return h
}
// ClientNumber reports how many clients are currently registered.
func (h *Hub) ClientNumber() int {
	h.Client2ConnMu.RLock()
	count := len(h.Clint2ConnMap)
	h.Client2ConnMu.RUnlock()
	return count
}
// IsAlive looks up clientId in the registry. It returns the registered
// ClientUser and true when present, or nil and false otherwise. It checks
// registration only and does not probe the connection itself.
func (h *Hub) IsAlive(clientId string) (conn *ClientUser, ok bool) {
	h.Client2ConnMu.RLock()
	defer h.Client2ConnMu.RUnlock()
	if found, present := h.Clint2ConnMap[clientId]; present {
		return found, true
	}
	return nil, false
}
// DelClient removes clientId from the registry. Removing an ID that is not
// registered is a no-op. The connection itself is not closed here.
func (h *Hub) DelClient(clientId string) {
	mu := &h.Client2ConnMu
	mu.Lock()
	delete(h.Clint2ConnMap, clientId)
	mu.Unlock()
}
// GetClientList returns a snapshot of the registered clients, keyed by ID.
//
// The result is a copy made while holding the read lock. Handing out the live
// map would let callers range over it after the lock is released, while Run or
// DelClient write to it concurrently. That is a data race and can crash the
// process with "concurrent map iteration and map write". Mutating the returned
// map does not affect the hub; use DelClient to remove a client.
func (h *Hub) GetClientList() map[string]*ClientUser {
	h.Client2ConnMu.RLock()
	defer h.Client2ConnMu.RUnlock()
	snapshot := make(map[string]*ClientUser, len(h.Clint2ConnMap))
	for id, client := range h.Clint2ConnMap {
		snapshot[id] = client
	}
	return snapshot
}
// Run is the hub's event loop. It serializes client registration,
// unregistration and broadcast fan-out, and must run in its own goroutine for
// the lifetime of the hub (Start does this). It never returns: a failure to
// reach one client must not stop the loop, because ServerWS and readPump block
// on the unbuffered register/unregister channels until Run receives from them.
//
// NOTE(review): the broadcast writes here and WriteMessage may call WriteJSON
// on the same connection at the same time. gorilla/websocket supports only one
// concurrent writer per connection. Consider a per-client write lock, or
// routing all writes through ClientUser.Send.
func (h *Hub) Run() {
	for {
		select {
		case client := <-h.registe_user:
			h.Client2ConnMu.Lock()
			h.Clint2ConnMap[client.ID] = client
			h.Client2ConnMu.Unlock()
			logx.Infof("register %s, clients: %d", client.ID, h.ClientNumber())
		case client := <-h.unregister_user:
			h.Client2ConnMu.Lock()
			// Only drop the entry if it still belongs to this connection. A
			// reconnect under the same ID may already have replaced it, and
			// that newer connection must stay registered.
			if current, ok := h.Clint2ConnMap[client.ID]; ok && current == client {
				delete(h.Clint2ConnMap, client.ID)
			}
			h.Client2ConnMu.Unlock()
			logx.Infof("unregister %s, clients: %d", client.ID, h.ClientNumber())
		case msg := <-h.broadcast_user:
			logx.Infof("broadcast message, clients: %d", h.ClientNumber())
			h.fanOut(msg)
		}
	}
}

// fanOut writes msg to every registered client, stamping each copy with the
// recipient's ID. A client whose write fails is closed and removed from the
// registry; the remaining clients still receive the message.
func (h *Hub) fanOut(msg SendMsg) {
	// Snapshot under the read lock so slow network writes don't hold it.
	h.Client2ConnMu.RLock()
	targets := make(map[string]*ClientUser, len(h.Clint2ConnMap))
	for id, client := range h.Clint2ConnMap {
		targets[id] = client
	}
	h.Client2ConnMu.RUnlock()

	for id, client := range targets {
		msg.ClientId = id
		if err := client.Conn.WriteJSON(msg); err != nil {
			logx.Errorf("broadcast to %s: %v", id, err)
			_ = client.Conn.Close() // best effort; the write already failed
			h.Client2ConnMu.Lock()
			if current, ok := h.Clint2ConnMap[id]; ok && current == client {
				delete(h.Clint2ConnMap, id)
			}
			h.Client2ConnMu.Unlock()
		}
	}
}
// MakeSendMsg assembles a SendMsg addressed to clientId with the given
// category, subject and payload. It neither validates nor sends; pass the
// result to SendMessage2Client.
func (h *Hub) MakeSendMsg(clientId, category, subject string, data interface{}) SendMsg {
	return SendMsg{
		ClientId: clientId,
		Category: category,
		Message: ItemData{
			Subject: subject,
			Data:    data,
		},
	}
}
// SendMessage2Client validates s and queues it on ToClientChan for delivery by
// WriteMessage. Category, ClientId and Message.Subject must all be non-empty;
// they are checked in that order and the first empty field is reported.
// The send blocks while the channel's buffer is full.
func (h *Hub) SendMessage2Client(s SendMsg) error {
	switch {
	case s.Category == "":
		return fmt.Errorf("category is empty")
	case s.ClientId == "":
		return fmt.Errorf("ClientId is empty")
	case s.Message.Subject == "":
		return fmt.Errorf("Subject is empty")
	}
	h.ToClientChan <- s
	return nil
}
// WriteMessage delivers messages queued on ToClientChan (by
// SendMessage2Client) to the client named in each message's ClientId.
// Messages for IDs that are not currently registered are dropped. It is meant
// to run in its own goroutine for the lifetime of the hub (Start does this).
//
// A failed write closes that client's connection and removes it from the
// registry, then the loop continues. Returning instead would stop all further
// delivery, and once the channel buffer filled, every SendMessage2Client
// caller would block forever.
//
// NOTE(review): Run's broadcast may write to the same connection concurrently;
// gorilla/websocket allows only one concurrent writer per connection.
func (h *Hub) WriteMessage() {
	for sendMsg := range h.ToClientChan {
		client, ok := h.IsAlive(sendMsg.ClientId)
		if !ok {
			continue // recipient not connected
		}
		if err := client.Conn.WriteJSON(sendMsg); err != nil {
			logx.Errorf("write to %s: %v", sendMsg.ClientId, err)
			_ = client.Conn.Close() // best effort; the write already failed
			h.Client2ConnMu.Lock()
			// Don't evict a newer connection that reused the same ID.
			if current, found := h.Clint2ConnMap[sendMsg.ClientId]; found && current == client {
				delete(h.Clint2ConnMap, sendMsg.ClientId)
			}
			h.Client2ConnMu.Unlock()
		}
	}
}
// PingTimer starts the heartbeat in a background goroutine and returns
// immediately. Every heartbeatInterval it prints the connection count and
// sends a websocket ping (10s write deadline) to each registered client.
// A client whose ping fails is closed and removed, and the remaining clients
// are still pinged. Previously the goroutine exited on the first failure,
// which silently ended heartbeats for everyone.
func (h *Hub) PingTimer() {
	go func() {
		ticker := time.NewTicker(heartbeatInterval)
		defer ticker.Stop()
		for range ticker.C {
			fmt.Printf("总连接数:%d\n", h.ClientNumber())

			// Snapshot under the read lock. Pings can block for up to the
			// write deadline and must not hold the lock or race with Run's
			// map writes.
			h.Client2ConnMu.RLock()
			clients := make(map[string]*ClientUser, len(h.Clint2ConnMap))
			for id, c := range h.Clint2ConnMap {
				clients[id] = c
			}
			h.Client2ConnMu.RUnlock()

			for clientId, user := range clients {
				deadline := time.Now().Add(10 * time.Second)
				err := user.Conn.WriteControl(websocket.PingMessage, nil, deadline)
				if err == nil {
					continue
				}
				_ = user.Conn.Close() // best effort; the ping already failed
				h.Client2ConnMu.Lock()
				// Don't evict a newer connection that reused the same ID.
				if current, ok := h.Clint2ConnMap[clientId]; ok && current == user {
					delete(h.Clint2ConnMap, clientId)
				}
				h.Client2ConnMu.Unlock()
				fmt.Printf("发送心跳失败: %s %v 总连接数:%d\n", clientId, err, h.ClientNumber())
			}
		}
	}()
}
client
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package hub
import (
"fmt"
"net/http"
"time"
// "encoding/json"
// "net/http"
// "time"
"github.com/gorilla/websocket"
"github.com/zeromicro/go-zero/core/logx"
)
// Websocket tuning for client connections.
const (
	// writeWaitManage bounds a single write to the peer.
	// NOTE(review): not referenced anywhere in this file.
	writeWaitManage = time.Second * 10
	// pongWaitManage is the read deadline readPump sets, and re-arms on each pong.
	pongWaitManage = time.Second * 60
	// pingPeriodManage is 90% of pongWaitManage, so a ping would land before
	// the peer's read deadline. It must stay below pongWaitManage.
	// NOTE(review): unused here; the hub pings on heartbeatInterval instead.
	pingPeriodManage = pongWaitManage * 9 / 10
	// maxMessageSizeManage caps an inbound message in bytes (readPump's SetReadLimit).
	maxMessageSizeManage = 8192
)
// Byte sequences used by the message normalization that is currently
// commented out in readPump (newline -> space).
var (
	newlineManage = []byte("\n")
	spaceManage   = []byte(" ")
)
// upgraderManage turns HTTP requests into websocket connections using 1024-byte
// read and write buffers. CheckOrigin is left nil, so gorilla/websocket's
// default same-origin check applies.
var upgraderManage = websocket.Upgrader{
	WriteBufferSize: 1024,
	ReadBufferSize:  1024,
}
// ClientUser is a middleman between the websocket connection and the hub.
// ServerWS creates one per accepted connection and registers it with the hub.
type ClientUser struct {
	// Hub this client is registered with; readPump unregisters from it on exit.
	hub *Hub
	// The websocket connection. The hub's goroutines write to it directly.
	Conn *websocket.Conn
	// Intended as a buffered channel of outbound messages.
	// NOTE(review): never initialized or read in this file (ServerWS leaves it
	// nil); all writes go straight through Conn.
	Send chan SendMsg
	// Client identifier: the key under which the hub stores this client and
	// the value SendMsg.ClientId must match.
	// NOTE(review): the old comment said "user name, from query name", but
	// ServerWS just receives it as its id argument; confirm its source.
	ID string
}
// SendMsg is the envelope written to a client's websocket as JSON (via WriteJSON).
type SendMsg struct {
	// ClientId names the recipient. Run overwrites it with each recipient's ID
	// when broadcasting.
	ClientId string `json:"clientId"`
	// Category must be non-empty for SendMessage2Client to accept the message.
	Category string `json:"category"`
	// Message carries the subject and payload.
	Message ItemData `json:"message"`
}

// ItemData is the body of a SendMsg.
type ItemData struct {
	// Subject must be non-empty for SendMessage2Client to accept the message.
	Subject string `json:"subject"`
	// Data is an arbitrary payload; it must be JSON-encodable for WriteJSON to succeed.
	Data interface{} `json:"data"`
}
// readPump drains inbound frames from the websocket connection.
//
// ServerWS runs it in a dedicated goroutine per connection, so it is the only
// reader of c.Conn. The read deadline starts at pongWaitManage and is pushed
// out again only when a pong arrives; data messages do not extend it. When any
// read fails (including the deadline expiring), the client is handed to the
// hub for unregistration and the connection is closed. Received payloads are
// currently printed and otherwise ignored.
func (c *ClientUser) readPump() {
	defer func() {
		c.hub.unregister_user <- c
		c.Conn.Close()
	}()

	extendDeadline := func() {
		c.Conn.SetReadDeadline(time.Now().Add(pongWaitManage))
	}
	c.Conn.SetReadLimit(maxMessageSizeManage)
	extendDeadline()
	c.Conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})

	for {
		_, payload, err := c.Conn.ReadMessage()
		fmt.Println("use recv msg:", string(payload), err)
		if err == nil {
			// TODO: dispatch the payload (callback registration is not implemented yet).
			continue
		}
		// Normal closes (going away / abnormal closure) are not worth reporting.
		if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
			fmt.Println("err: ReadMessage")
			logx.Infof("error: %v", err)
		}
		return
	}
}
// ServerWS handles a websocket request from the peer. It upgrades the HTTP
// request, registers the resulting client with hub under id, and starts the
// client's read loop.
//
// An empty id is rejected with 400 Bad Request before upgrading. If the upgrade
// itself fails, the upgrader has already written an HTTP error response.
// Registration is an unbuffered channel send, so this call blocks until the
// hub's Run loop accepts the client.
//
// There is no per-connection write pump. Outbound frames are written by the
// hub's goroutines (Run, WriteMessage, PingTimer).
func ServerWS(id string, hub *Hub, w http.ResponseWriter, r *http.Request) {
	if id == "" {
		logx.Error("id is empty")
		http.Error(w, "id is empty", http.StatusBadRequest)
		return
	}

	conn, err := upgraderManage.Upgrade(w, r, nil)
	if err != nil {
		logx.Errorf("websocket upgrade for %s: %v", id, err)
		return
	}

	client := &ClientUser{
		hub:  hub,
		Conn: conn,
		ID:   id,
	}
	hub.registe_user <- client
	// The read loop owns the connection from here; the handler can return.
	go client.readPump()
}
使用
// main starts the websocket hub, registers the HTTP routes, and serves on the
// port from the first command-line argument (default 8080). Any server error
// panics. The hub value can be stored somewhere globally reachable so other
// code can call SendMessage2Client on it.
func main() {
	port := getPort()
	wsHub := hub.Start()
	routers.Init(wsHub)
	fmt.Printf("服务器启动成功,端口号:%s\n", port)
	addr := ":" + port
	if err := http.ListenAndServe(addr, nil); err != nil {
		panic(err)
	}
}
// getPort returns the listen port: the first command-line argument when it is
// present and non-empty, otherwise "8080".
func getPort() string {
	if len(os.Args) > 1 && os.Args[1] != "" {
		return os.Args[1]
	}
	return "8080"
}
// Init registers the websocket endpoint /ws on http.DefaultServeMux and
// connects it to hub h.
//
// Each connection identifies itself with the "id" query parameter
// (e.g. /ws?id=<client id>). That value becomes the hub key that
// SendMessage2Client addresses. A fixed ID for every connection would make
// each new connection replace the previous one in the hub. Requests without an
// id are rejected by ServerWS.
//
// NOTE(review): the id is trusted as sent by the client. In production, derive
// it from an authenticated session so clients cannot impersonate each other.
func Init(h *hub.Hub) {
	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		hub.ServerWS(r.URL.Query().Get("id"), h, w, r)
	})
}
// Write data to a user: queue a message for the client whose hub ID is the
// MD5 of the numeric user id. WriteMessage drops it if no client is currently
// registered under that ID.
// NOTE(review): this is a fragment meant to sit inside a function that returns
// an error and has `id int64` in scope. SendMessage2Client is a method on
// *hub.Hub, so call it on the instance returned by hub.Start()
// (e.g. h.SendMessage2Client(msg)), not on the package.
// "分类" (category) and "行为主题" (action subject) are placeholder values;
// Message.Data is left nil, so set msg.Message.Data to attach a payload.
var msg hub.SendMsg
msg.ClientId = Int64ToMD5(id)
msg.Category = "分类"
msg.Message.Subject = "行为主题"
if err := hub.SendMessage2Client(msg); err != nil {
	return err
}
// Int64ToMD5 returns the lowercase hex MD5 digest of num's base-10 string
// form; for example, Int64ToMD5(1) == "c4ca4238a0b923820dcc509a6f75849b".
// In this example it turns a numeric user ID into a hub client ID.
//
// NOTE(review): MD5 of a sequential integer is trivially guessable; do not
// treat the result as a secret or as authentication.
func Int64ToMD5(num int64) string {
	sum := md5.Sum([]byte(strconv.FormatInt(num, 10)))
	return hex.EncodeToString(sum[:])
}