package com.mes.tools;
|
|
import cn.hutool.json.JSONObject;
|
import cn.hutool.json.JSONUtil;
|
import org.slf4j.Logger;
|
import org.slf4j.LoggerFactory;
|
import org.springframework.context.ConfigurableApplicationContext;
|
import org.springframework.stereotype.Component;
|
|
import javax.websocket.*;
|
import javax.websocket.server.PathParam;
|
import javax.websocket.server.ServerEndpoint;
|
import java.io.IOException;
|
import java.util.*;
|
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.CopyOnWriteArrayList;
|
import java.util.concurrent.Semaphore;
|
|
@ServerEndpoint(value = "/api/talk/{username}")
|
@Component
|
public class WebSocketServer {
|
|
private static final Logger log = LoggerFactory.getLogger(WebSocketServer.class);
|
private static final int MAX_MESSAGE_SIZE = 50000; // 单次消息分块阈值
|
private static final Semaphore semaphore = new Semaphore(100); // 流量控制
|
|
// 按用户名分组存储Session(线程安全)
|
private static final Map<String, List<WebSocketServer>> sessionMap = new ConcurrentHashMap<>();
|
|
// 当前连接的用户名和Session
|
private String webSocketName;
|
private Session session;
|
private final List<String> messages = new CopyOnWriteArrayList<>(); // 线程安全消息记录
|
|
/**
|
* 连接建立
|
*/
|
@OnOpen
|
public void onOpen(Session session, @PathParam("webSocketName") String webSocketName) {
|
this.webSocketName = webSocketName;
|
this.session = session;
|
|
sessionMap.computeIfAbsent(webSocketName, k -> new CopyOnWriteArrayList<>())
|
.add(this);
|
|
log.info("用户连接: webSocketName={}, 当前会话数: {}", webSocketName,
|
sessionMap.getOrDefault(webSocketName, Collections.emptyList()).size());
|
}
|
|
/**
|
* 连接关闭
|
*/
|
@OnClose
|
public void onClose() {
|
List<WebSocketServer> sessions = sessionMap.get(webSocketName);
|
if (sessions != null) {
|
sessions.remove(this);
|
if (sessions.isEmpty()) {
|
sessionMap.remove(webSocketName);
|
}
|
log.info("用户断开: webSocketName={}, 剩余会话数: {}", webSocketName,
|
sessionMap.getOrDefault(webSocketName, Collections.emptyList()).size());
|
}
|
}
|
|
/**
|
* 接收消息
|
*/
|
@OnMessage
|
public void onMessage(String message) {
|
log.info("收到消息: webSocketName={}, content={}", webSocketName, message);
|
JSONObject obj = JSONUtil.parseObj(message);
|
messages.add(obj.getStr("data")); // 存储消息历史
|
}
|
|
/**
|
* 错误处理
|
*/
|
@OnError
|
public void onError(Throwable error) {
|
log.error("WebSocket错误: webSocketName={}", webSocketName, error);
|
}
|
|
/**
|
* 向当前用户的所有会话发送消息
|
*/
|
public void sendToWeb(String webSocketName, String message) {
|
List<WebSocketServer> sessions = sessionMap.get(webSocketName);
|
if (sessions == null) return;
|
sessions.forEach(ws -> {
|
try {
|
semaphore.acquire();
|
ws.sendChunkedMessage(message);
|
} catch (InterruptedException e) {
|
Thread.currentThread().interrupt();
|
} catch (Exception e) {
|
log.error("推送失败: webSocketName={}", webSocketName, e);
|
} finally {
|
semaphore.release();
|
}
|
});
|
}
|
|
/**
|
* 分块发送大消息
|
*/
|
private void sendChunkedMessage(String message) {
|
if (!session.isOpen()) return;
|
try {
|
if (message.length() <= MAX_MESSAGE_SIZE) {
|
session.getBasicRemote().sendText(message);
|
return;
|
}
|
|
// 分块发送
|
int chunks = (int) Math.ceil((double) message.length() / MAX_MESSAGE_SIZE);
|
for (int i = 0; i < chunks; i++) {
|
int start = i * MAX_MESSAGE_SIZE;
|
int end = Math.min(start + MAX_MESSAGE_SIZE, message.length());
|
String chunk = message.substring(start, end) + (i == chunks - 1 ? "<END>" : "");
|
session.getBasicRemote().sendText(chunk);
|
}
|
} catch (IOException e) {
|
log.error("消息发送失败: webSocketName={}", webSocketName, e);
|
}
|
}
|
|
// --- 工具方法 ---
|
public static Set<String> getOnlineUsers() {
|
return sessionMap.keySet();
|
}
|
|
public List<String> getMessages() {
|
return Collections.unmodifiableList(messages);
|
}
|
|
public void clearMessages() {
|
messages.clear();
|
}
|
}
|