ZengTao
2025-08-22 04801febfbdf46d6cb34724862c640f27887be96
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package com.mes.tools;
 
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.stereotype.Component;
 
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Semaphore;
 
@ServerEndpoint(value = "/api/talk/{webSocketName}")
@Component
public class WebSocketServer {
 
    private static final Logger log = LoggerFactory.getLogger(WebSocketServer.class);
    private static final int MAX_MESSAGE_SIZE = 50000; // 单次消息分块阈值
    private static final Semaphore semaphore = new Semaphore(100); // 流量控制
 
    // 按用户名分组存储Session(线程安全)
    private static final Map<String, List<WebSocketServer>> sessionMap = new ConcurrentHashMap<>();
 
    // 当前连接的用户名和Session
    private String webSocketName;
    private Session session;
    private final List<String> messages = new CopyOnWriteArrayList<>(); // 线程安全消息记录
 
    /**
     * 连接建立
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("webSocketName") String webSocketName) {
        this.webSocketName = webSocketName;
        this.session = session;
 
        sessionMap.computeIfAbsent(webSocketName, k -> new CopyOnWriteArrayList<>())
                .add(this);
 
        log.info("用户连接: webSocketName={}, 当前会话数: {}", webSocketName,
                sessionMap.getOrDefault(webSocketName, Collections.emptyList()).size());
    }
 
    /**
     * 连接关闭
     */
    @OnClose
    public void onClose() {
        List<WebSocketServer> sessions = sessionMap.get(webSocketName);
        if (sessions != null) {
            sessions.remove(this);
            if (sessions.isEmpty()) {
                sessionMap.remove(webSocketName);
            }
            log.info("用户断开: webSocketName={}, 剩余会话数: {}", webSocketName,
                    sessionMap.getOrDefault(webSocketName, Collections.emptyList()).size());
        }
    }
 
    /**
     * 接收消息
     */
    @OnMessage
    public void onMessage(String message) {
        log.info("收到消息: webSocketName={}, content={}", webSocketName, message);
        JSONObject obj = JSONUtil.parseObj(message);
        messages.add(obj.getStr("data")); // 存储消息历史
    }
 
    /**
     * 错误处理
     */
    @OnError
    public void onError(Throwable error) {
        log.error("WebSocket错误: webSocketName={}", webSocketName, error);
    }
 
    /**
     * 向当前用户的所有会话发送消息
     */
    public void sendToWeb(String webSocketName, String message) {
        List<WebSocketServer> sessions = sessionMap.get(webSocketName);
        if (sessions == null) return;
        sessions.forEach(ws -> {
            try {
                semaphore.acquire();
                ws.sendChunkedMessage(message);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch (Exception e) {
                log.error("推送失败: webSocketName={}", webSocketName, e);
            } finally {
                semaphore.release();
            }
        });
    }
 
    /**
     * 分块发送大消息
     */
    private void sendChunkedMessage(String message) {
        if (!session.isOpen()) return;
        try {
            if (message.length() <= MAX_MESSAGE_SIZE) {
                session.getBasicRemote().sendText(message);
                return;
            }
 
            // 分块发送
            int chunks = (int) Math.ceil((double) message.length() / MAX_MESSAGE_SIZE);
            for (int i = 0; i < chunks; i++) {
                int start = i * MAX_MESSAGE_SIZE;
                int end = Math.min(start + MAX_MESSAGE_SIZE, message.length());
                String chunk = message.substring(start, end) + (i == chunks - 1 ? "<END>" : "");
                session.getBasicRemote().sendText(chunk);
            }
        } catch (IOException e) {
            log.error("消息发送失败: webSocketName={}", webSocketName, e);
        }
    }
 
    // --- 工具方法 ---
    public static Set<String> getOnlineUsers() {
        return sessionMap.keySet();
    }
 
    public List<String> getMessages() {
        return Collections.unmodifiableList(messages);
    }
 
    public void clearMessages() {
        messages.clear();
    }
}