Golang压测ws

普通版,发完就退出

package main

import (
	"fmt"
	"log"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

func main() {
	// 设置要压测的WebSocket服务地址
	url := "ws://192.168.252.128:8080/"

	// 设置并发请求数量
	concurrency := 100

	// 设置每个连接的发送消息数量
	messageCount := 10

	// 等待组,用于等待所有goroutine完成
	var wg sync.WaitGroup

	// 开始时间
	startTime := time.Now()

	// 根据并发请求数量启动goroutine
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			// 建立连接
			conn, _, err := websocket.DefaultDialer.Dial(url, nil)
			if err != nil {
				log.Fatal("无法建立WebSocket连接:", err)
			}
			defer conn.Close()

			// 发送指定数量的消息
			for j := 0; j < messageCount; j++ {
				message := fmt.Sprintf("这是第 %d 个消息", j)
				err = conn.WriteMessage(websocket.TextMessage, []byte(message))
				if err != nil {
					log.Println("发送消息失败:", err)
					return
				}
				log.Println("发送消息:", message)

				// 在发送每个消息之后,可以在这里等待一段时间以模拟实际场景的负载
				//time.Sleep(time.Second)
			}

			// 读取服务器响应
			_, message, err := conn.ReadMessage()
			if err != nil {
				log.Println("读取服务器响应失败:", err)
				return
			}
			log.Println("服务器响应:", string(message))
		}()
	}

	// 等待所有goroutine完成
	wg.Wait()

	// 计算执行时间
	executionTime := time.Since(startTime)

	// 打印统计信息
	log.Printf("压测完成,共发送 %d 个消息,每个连接 %d 个消息\n", concurrency*messageCount, messageCount)
	log.Printf("总执行时间: %.2f 秒\n", executionTime.Seconds())
	log.Printf("每秒请求数: %.2f\n", float64(concurrency*messageCount)/executionTime.Seconds())
}

长连接版

package main

import (
	"log"
	"net/url"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

func main() {
	// 设置连接的 WebSocket 服务器地址
	u := url.URL{Scheme: "ws", Host: "192.168.252.128:8080", Path: "/"}

	log.Printf("连接到服务器:%s", u.String())

	// 设置并发连接数
	concurrency := 10000

	// 等待组,用于等待所有连接的goroutine完成
	var wg sync.WaitGroup

	// 创建并发连接
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func(j int) {
			defer wg.Done()

			for {
				conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
				if err != nil {
					log.Fatal("连接错误:", err)
				}

				done := make(chan struct{})
				quit := make(chan struct{})

				// 启动 goroutine 用于接收服务器消息
				go func() {
					defer close(done)
					for {
						_, message, err := conn.ReadMessage()
						if err != nil {
							log.Println("读取消息错误:", err)
							quit <- struct{}{}
							return
						}
						log.Printf("收到服务器消息:%s  %d", message, j)
					}
				}()

				// 启动 goroutine 用于定时发送消息至服务器
				go func() {
					ticker := time.NewTicker(time.Second) // 发送间隔
					defer ticker.Stop()
					for {
						select {
						case <-ticker.C:
							err := conn.WriteMessage(websocket.TextMessage, []byte("Hello, server!"))
							if err != nil {
								log.Println("发送消息错误:", err)
								quit <- struct{}{}
								return
							}
						case <-quit:
							return
						}
					}
				}()
				<-done
				time.Sleep(3 * time.Second) // 延迟一段时间后重新连接
			}
		}(i)
	}

	// 等待所有连接的goroutine完成
	wg.Wait()
}

大牛写的长连接版

package main

import (
	"flag"
	"fmt"
	"github.com/gorilla/websocket"
	"io"
	"log"
	"net/url"
	"os"
	"time"
)

var (
	ip          = flag.String("ip", "127.0.0.1", "远程地址")
	connections = flag.Int("conn", 1, "并发连接")
)

func main() {
	flag.Usage = func() {
		io.WriteString(os.Stderr, `go run main.go -ip=129.0.0.1 -conn=10`)
		flag.PrintDefaults()
	}
	flag.Parse()

	u := url.URL{Scheme: "ws", Host: *ip + ":8000", Path: "/"}
	log.Printf("远程地址 %s", u.String())
	var conns []*websocket.Conn
	for i := 0; i < *connections; i++ {
		c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
		if err != nil {
			fmt.Println("连接失败", i, err)
			break
		}
		conns = append(conns, c)
		defer func() {
			c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second))
			time.Sleep(time.Second)
			c.Close()
		}()
	}

	log.Printf("初始化连接数 %d", len(conns))
	tts := time.Second
	if *connections > 100 {
		tts = time.Millisecond * 5
	}
	for {
		for i := 0; i < len(conns); i++ {
			time.Sleep(tts)
			conn := conns[i]
			log.Printf("Conn %d sending message", i)
			if err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(time.Second*5)); err != nil {
				fmt.Printf("心跳失败: %v", err)
			}
			conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Hello from conn %v", i)))
		}
	}
}

简单server

package main

import (
	"fmt"
	"log"
	"net/http"
	"sync"
	"time"

	"os/exec"

	"github.com/gorilla/websocket"
)

type Client struct {
	conn          *websocket.Conn
	id            string
	send          chan []byte
	lastHeartbeat time.Time
}

var (
	clients = make(map[string]*Client)
	mutex   sync.RWMutex
)

var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		// 自定义的检查逻辑,可以根据需要进行修改
		// 这里简单地返回 true 允许所有来源
		return true
	},
}

func main() {
	http.HandleFunc("/ws", handleWebSocket)
	err := http.ListenAndServe(":8082", nil)
	if err != nil {
		fmt.Println(err)
	}
}

func handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("Upgrade error:", err)
		return
	}

	id, err := exec.Command("uuidgen").Output()
	id1 := string(id)
	fmt.Println(id)
	mutex.RLock()
	//判断是否是断线重连
	if _, ok := clients[id1]; ok {
		clients[id1].conn = conn
		client := clients[id1]
		go client.readPump()
		go client.writePump()
		fmt.Printf("断线连接成功%v", client)
	} else {
		client := &Client{
			conn:          conn,
			id:            id1,
			send:          make(chan []byte),
			lastHeartbeat: time.Now(),
		}
		fmt.Printf("连接成功%v", client)
		clients[id1] = client
		go client.readPump()
		go client.writePump()
	}
	mutex.RUnlock()
	// 检查心跳
	go checkHeartbeat(id1)
}

func (client *Client) readPump() {
	defer func() {
		mutex.Lock()
		delete(clients, client.id)
		mutex.Unlock()
		client.conn.Close()
	}()

	for {
		_, message, err := client.conn.ReadMessage()
		if err != nil {
			if !websocket.IsCloseError(err, websocket.CloseNormalClosure) {
				log.Println("Read error:", err)
			}
			break
		}
		// 处理收到的消息
		log.Printf("Received message from client %s: %s\n", client.id, string(message))
		// 更新最后收到心跳的时间
		client.lastHeartbeat = time.Now()
	}
}

func (client *Client) writePump() {
	defer func() {
		client.conn.Close()
	}()

	for {
		select {
		case message, ok := <-client.send:
			if !ok {
				err := client.conn.WriteMessage(websocket.CloseMessage, []byte{})
				if err != nil {
					log.Println("Write error:", err)
				}
				return
			}
			err := client.conn.WriteMessage(websocket.TextMessage, message)
			if err != nil {
				log.Println("Write error:", err)
				return
			}
		}
	}
}

func checkHeartbeat(id string) {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// if client, ok := clients[id]; ok {
			// 	// 检查最后心跳时间,超过指定时间则踢掉客户端
			// 	//if time.Since(client.lastHeartbeat) > 10*time.Second {
			// 	//	client.send <- []byte("heartbeat timeout, disconnecting")
			// 	//	close(client.send)
			// 	//	return
			// 	//}

			// 	// 发送心跳消息
			// 	//client.send <- []byte("heartbeat")
			// } else {
			// 	return
			// }
		}
	}
}
posted @ 2023-12-25 10:50  朝阳1  阅读(104)  评论(0编辑  收藏  举报