main.go
package main
import (
"log"
"net/http"
_ "net/http/pprof"
"syscall"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
// epoller is the process-wide epoll instance. main assigns it from MkEpoll
// before serving; wsHandler registers connections on it and Start polls it.
var epoller *epoll
// wsHandler upgrades an incoming HTTP request to a WebSocket connection and
// registers the resulting connection with the global epoller. It does not read
// from the connection itself; reads happen in Start once the connection is
// reported ready.
func wsHandler(w http.ResponseWriter, r *http.Request) {
	conn, _, _, err := ws.UpgradeHTTP(r, w)
	if err != nil {
		// Surface failed handshakes instead of dropping them silently.
		log.Printf("Failed to upgrade connection %v", err)
		return
	}
	if err := epoller.Add(conn); err != nil {
		log.Printf("Failed to add connection %v", err)
		// The connection is not tracked anywhere, so it must be closed here.
		if err := conn.Close(); err != nil {
			log.Printf("Failed to close connection %v", err)
		}
	}
}
// main raises the open-file soft limit to the hard limit, starts the pprof
// listener and the epoll read loop in background goroutines, and then serves
// WebSocket upgrades on port 8000.
func main() {
	// Lift RLIMIT_NOFILE's soft limit up to the hard limit.
	limits := syscall.Rlimit{}
	err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &limits)
	if err != nil {
		panic(err)
	}
	limits.Cur = limits.Max
	err = syscall.Setrlimit(syscall.RLIMIT_NOFILE, &limits)
	if err != nil {
		panic(err)
	}

	// Profiling endpoint (handlers registered by the blank net/http/pprof import).
	go func() {
		pprofErr := http.ListenAndServe("localhost:6060", nil)
		if pprofErr != nil {
			log.Fatalf("pprof failed: %v", pprofErr)
		}
	}()

	// The poller must exist before the read loop or any handler touches it.
	epoller, err = MkEpoll()
	if err != nil {
		panic(err)
	}

	go Start()

	http.HandleFunc("/", wsHandler)
	err = http.ListenAndServe("0.0.0.0:8000", nil)
	if err != nil {
		log.Fatal(err)
	}
}
// Start runs the read loop: it repeatedly asks the global epoller for ready
// connections and reads one client message from each. A connection whose read
// fails is removed from the epoller and closed. The loop retries when the wait
// is interrupted (EINTR) and exits on any other wait error.
func Start() {
	for {
		connections, err := epoller.Wait()
		if err != nil {
			if err == syscall.EINTR {
				log.Printf("Epoll wait was interrupted, retrying...")
				continue
			}
			log.Printf("Failed to epoll wait: %v", err)
			break
		}
		for _, conn := range connections {
			if conn == nil {
				// An event for an fd the epoller no longer tracks. Skip just
				// this entry; the remaining ready connections still need to be
				// served in this pass.
				continue
			}
			if _, _, err := wsutil.ReadClientData(conn); err != nil {
				if err := epoller.Remove(conn); err != nil {
					log.Printf("Failed to remove %v", err)
				}
				conn.Close()
				continue
			}
			log.Printf("msg:xxxxxxxxx")
		}
	}
}
epoll.go
package main
import (
"golang.org/x/sys/unix"
"log"
"net"
"reflect"
"sync"
"syscall"
)
// epoll couples an epoll instance with the set of connections registered on it.
type epoll struct {
	// fd is the descriptor of the epoll instance itself.
	fd int

	// lock guards connections.
	lock *sync.RWMutex
	// connections maps a socket file descriptor to its connection.
	connections map[int]net.Conn
}
// MkEpoll creates a new epoll instance and returns it wrapped together with an
// empty connection table. It returns the error from unix.EpollCreate1 when the
// instance cannot be created.
func MkEpoll() (*epoll, error) {
	epfd, err := unix.EpollCreate1(0)
	if err != nil {
		return nil, err
	}

	poller := new(epoll)
	poller.fd = epfd
	poller.connections = map[int]net.Conn{}
	poller.lock = new(sync.RWMutex)
	return poller, nil
}
// Add registers conn's socket with the epoll instance for readable and hang-up
// events and records conn under its file descriptor. It returns the error from
// unix.EpollCtl if registration fails, in which case conn is not recorded.
func (e *epoll) Add(conn net.Conn) error {
	fd := websocketFD(conn)

	// Hold the lock across both the kernel registration and the map insert.
	// Wait resolves fds under the read lock, so it can never observe an event
	// for an fd that has been registered but not yet stored.
	e.lock.Lock()
	defer e.lock.Unlock()

	event := unix.EpollEvent{Events: unix.EPOLLIN | unix.EPOLLHUP, Fd: int32(fd)}
	if err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_ADD, fd, &event); err != nil {
		return err
	}
	e.connections[fd] = conn
	// Progress log every 100 connections.
	if len(e.connections)%100 == 0 {
		log.Printf("Total number of connections: %v", len(e.connections))
	}
	return nil
}
// Remove unregisters conn's socket from the epoll instance and forgets the
// connection. The table entry is dropped even when unix.EpollCtl fails, so a
// connection the caller is about to close is never left behind in the table;
// the EpollCtl error is still returned.
func (e *epoll) Remove(conn net.Conn) error {
	fd := websocketFD(conn)

	e.lock.Lock()
	defer e.lock.Unlock()

	err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_DEL, fd, nil)
	// Forget the connection regardless of the syscall outcome.
	delete(e.connections, fd)
	if err != nil {
		return err
	}
	// Progress log every 100 connections.
	if len(e.connections)%100 == 0 {
		log.Printf("Total number of connections: %v", len(e.connections))
	}
	return nil
}
// Wait blocks until at least one registered socket is ready or the timeout
// elapses, and returns the connections for the ready sockets. Events whose fd
// is not in the connection table are skipped, so the result never contains nil
// entries. It returns the error from unix.EpollWait unchanged (including
// EINTR, which callers may want to retry).
func (e *epoll) Wait() ([]net.Conn, error) {
	const (
		maxEvents   = 100 // events fetched per call
		timeoutMsec = 100 // EpollWait timeout, in milliseconds
	)

	events := make([]unix.EpollEvent, maxEvents)
	n, err := unix.EpollWait(e.fd, events, timeoutMsec)
	if err != nil {
		return nil, err
	}

	e.lock.RLock()
	defer e.lock.RUnlock()

	connections := make([]net.Conn, 0, n)
	for _, ev := range events[:n] {
		if conn, ok := e.connections[int(ev.Fd)]; ok {
			connections = append(connections, conn)
		}
	}
	return connections, nil
}
// websocketFD returns the OS file descriptor behind conn.
//
// It first uses the supported syscall.Conn interface. When conn does not
// implement it, or the raw connection cannot be controlled, it falls back to
// reading the unexported conn.fd.pfd.Sysfd field chain via reflection; that
// fallback panics if conn does not have this field layout.
func websocketFD(conn net.Conn) int {
	if sc, ok := conn.(syscall.Conn); ok {
		if raw, err := sc.SyscallConn(); err == nil {
			fd := -1
			ctlErr := raw.Control(func(sysfd uintptr) { fd = int(sysfd) })
			if ctlErr == nil && fd >= 0 {
				return fd
			}
		}
	}

	// Fallback: dig the descriptor out of the connection's unexported fields.
	tcpConn := reflect.Indirect(reflect.ValueOf(conn)).FieldByName("conn")
	fdVal := tcpConn.FieldByName("fd")
	pfdVal := reflect.Indirect(fdVal).FieldByName("pfd")
	return int(pfdVal.FieldByName("Sysfd").Int())
}
测试客户端
package main
import (
"flag"
"fmt"
"io"
"log"
"net/url"
"os"
"time"
"github.com/gorilla/websocket"
)
// ip is the address of the server to connect to (-ip).
var ip = flag.String("ip", "127.0.0.1", "server IP")

// connections is how many WebSocket connections to open (-conn).
var connections = flag.Int("conn", 10, "number of websocket connections")
// main is a load-generating WebSocket client. It opens up to -conn connections
// to ws://<-ip>:8000/ and then loops forever, sending a ping and a text message
// on each connection in turn. Dialing stops at the first failure and the
// client carries on with the connections established so far; if there are
// none, it exits.
func main() {
	flag.Usage = func() {
		io.WriteString(os.Stderr, `Websockets client generator
Example usage: ./client -ip=172.17.0.1 -conn=10
`)
		flag.PrintDefaults()
	}
	flag.Parse()

	u := url.URL{Scheme: "ws", Host: *ip + ":8000", Path: "/"}
	log.Printf("Connecting to %s", u.String())

	var conns []*websocket.Conn
	for i := 0; i < *connections; i++ {
		c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
		if err != nil {
			fmt.Println("Failed to connect", i, err)
			break
		}
		conns = append(conns, c)
		// Graceful close on return from main. The send loop below never
		// returns, so in practice these only matter if that changes.
		defer func() {
			c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second))
			time.Sleep(time.Second)
			c.Close()
		}()
	}
	log.Printf("Finished initializing %d connections", len(conns))

	// With no connections the send loop below would spin without ever
	// sleeping; there is nothing to do, so stop here.
	if len(conns) == 0 {
		return
	}

	// Pause between sends: 1s for small runs, 5ms when many connections were requested.
	tts := time.Second
	if *connections > 100 {
		tts = time.Millisecond * 5
	}

	for {
		for i, conn := range conns {
			time.Sleep(tts)
			log.Printf("Conn %d sending message", i)
			if err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(time.Second*5)); err != nil {
				fmt.Printf("Failed to send ping: %v\n", err)
			}
			if err := conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Hello from conn %v", i))); err != nil {
				fmt.Printf("Failed to send message: %v\n", err)
			}
		}
	}
}
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 零经验选手,Compose 一天开发一款小游戏!
· 通过 API 将Deepseek响应流式内容输出到前端
· 因为Apifox不支持离线,我果断选择了Apipost!