package com.mes.tools; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.stereotype.Component; import javax.websocket.*; import javax.websocket.server.PathParam; import javax.websocket.server.ServerEndpoint; import java.io.IOException; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Semaphore; @ServerEndpoint(value = "/api/talk/{webSocketName}") @Component public class WebSocketServer { private static final Logger log = LoggerFactory.getLogger(WebSocketServer.class); private static final int MAX_MESSAGE_SIZE = 50000; // 单次消息分块阈值 private static final Semaphore semaphore = new Semaphore(100); // 流量控制 // 按用户名分组存储Session(线程安全) private static final Map> sessionMap = new ConcurrentHashMap<>(); // 当前连接的用户名和Session private String webSocketName; private Session session; private final List messages = new CopyOnWriteArrayList<>(); // 线程安全消息记录 /** * 连接建立 */ @OnOpen public void onOpen(Session session, @PathParam("webSocketName") String webSocketName) { this.webSocketName = webSocketName; this.session = session; sessionMap.computeIfAbsent(webSocketName, k -> new CopyOnWriteArrayList<>()) .add(this); log.info("用户连接: webSocketName={}, 当前会话数: {}", webSocketName, sessionMap.getOrDefault(webSocketName, Collections.emptyList()).size()); } /** * 连接关闭 */ @OnClose public void onClose() { List sessions = sessionMap.get(webSocketName); if (sessions != null) { sessions.remove(this); if (sessions.isEmpty()) { sessionMap.remove(webSocketName); } log.info("用户断开: webSocketName={}, 剩余会话数: {}", webSocketName, sessionMap.getOrDefault(webSocketName, Collections.emptyList()).size()); } } /** * 接收消息 */ @OnMessage public void onMessage(String message) { log.info("收到消息: webSocketName={}, content={}", webSocketName, message); JSONObject obj = JSONUtil.parseObj(message); messages.add(obj.getStr("data")); // 存储消息历史 } /** * 错误处理 */ @OnError public void onError(Throwable error) { log.error("WebSocket错误: webSocketName={}", webSocketName, error); } /** * 向当前用户的所有会话发送消息 */ public void sendToWeb(String webSocketName, String message) { List sessions = sessionMap.get(webSocketName); if (sessions == null) return; sessions.forEach(ws -> { try { semaphore.acquire(); ws.sendChunkedMessage(message); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } catch (Exception e) { log.error("推送失败: webSocketName={}", webSocketName, e); } finally { semaphore.release(); } }); } /** * 分块发送大消息 */ private void sendChunkedMessage(String message) { if (!session.isOpen()) return; try { if (message.length() <= MAX_MESSAGE_SIZE) { session.getBasicRemote().sendText(message); return; } // 分块发送 int chunks = (int) Math.ceil((double) message.length() / MAX_MESSAGE_SIZE); for (int i = 0; i < chunks; i++) { int start = i * MAX_MESSAGE_SIZE; int end = Math.min(start + MAX_MESSAGE_SIZE, message.length()); String chunk = message.substring(start, end) + (i == chunks - 1 ? "" : ""); session.getBasicRemote().sendText(chunk); } } catch (IOException e) { log.error("消息发送失败: webSocketName={}", webSocketName, e); } } // --- 工具方法 --- public static Set getOnlineUsers() { return sessionMap.keySet(); } public List getMessages() { return Collections.unmodifiableList(messages); } public void clearMessages() { messages.clear(); } }