Commit c3be5a33 by guoxuejian

Enhance WebSocket session management by tracking principal names and improving…

Enhance WebSocket session management by tracking principal names and improving closed session handling
parent 6112355b
...@@ -32,6 +32,7 @@ public class MyWebSocketHandler extends WebSocketDefaultHandler { ...@@ -32,6 +32,7 @@ public class MyWebSocketHandler extends WebSocketDefaultHandler {
private ObjectMapper objectMapper = new ObjectMapper(); private ObjectMapper objectMapper = new ObjectMapper();
private Map<String, Boolean> authenticatedSessions = new ConcurrentHashMap<>(); private Map<String, Boolean> authenticatedSessions = new ConcurrentHashMap<>();
private Map<String, CustomClaim> sessionClaims = new ConcurrentHashMap<>(); private Map<String, CustomClaim> sessionClaims = new ConcurrentHashMap<>();
private Map<String, String> sessionPrincipalNames = new ConcurrentHashMap<>();
MyWebSocketHandler(WebSocketHandler delegate, IWebSocketManageService webSocketManageService) { MyWebSocketHandler(WebSocketHandler delegate, IWebSocketManageService webSocketManageService) {
super(delegate); super(delegate);
...@@ -46,6 +47,8 @@ public class MyWebSocketHandler extends WebSocketDefaultHandler { ...@@ -46,6 +47,8 @@ public class MyWebSocketHandler extends WebSocketDefaultHandler {
if (StringUtils.hasText(principalName) && !principalName.startsWith("temp-")) { if (StringUtils.hasText(principalName) && !principalName.startsWith("temp-")) {
webSocketManageService.put(principalName, new MyConcurrentWebSocketSession(session)); webSocketManageService.put(principalName, new MyConcurrentWebSocketSession(session));
authenticatedSessions.put(session.getId(), true); authenticatedSessions.put(session.getId(), true);
// 记录 principalName 用于 afterConnectionClosed 时清理
sessionPrincipalNames.put(session.getId(), principalName);
log.debug("{} is connected (pre-authenticated). ID: {}. WebSocketSession[current count: {}]", log.debug("{} is connected (pre-authenticated). ID: {}. WebSocketSession[current count: {}]",
principalName, session.getId(), webSocketManageService.getConnectedCount()); principalName, session.getId(), webSocketManageService.getConnectedCount());
} else { } else {
...@@ -60,18 +63,26 @@ public class MyWebSocketHandler extends WebSocketDefaultHandler { ...@@ -60,18 +63,26 @@ public class MyWebSocketHandler extends WebSocketDefaultHandler {
Boolean isAuthenticated = authenticatedSessions.get(sessionId); Boolean isAuthenticated = authenticatedSessions.get(sessionId);
if (Boolean.TRUE.equals(isAuthenticated)) { if (Boolean.TRUE.equals(isAuthenticated)) {
// 优先用 sessionClaims(消息认证方式),其次用 sessionPrincipalNames(URL token 认证方式)
CustomClaim claim = sessionClaims.get(sessionId); CustomClaim claim = sessionClaims.get(sessionId);
if (claim != null) { if (claim != null) {
String key = claim.getWorkspaceId() + "/" + claim.getUserType() + "/" + claim.getId(); String key = claim.getWorkspaceId() + "/" + claim.getUserType() + "/" + claim.getId();
webSocketManageService.remove(key, sessionId); webSocketManageService.remove(key, sessionId);
log.debug("{} is disconnected. ID: {}. WebSocketSession[current count: {}]", log.debug("{} is disconnected. ID: {}. WebSocketSession[current count: {}]",
key, sessionId, webSocketManageService.getConnectedCount()); key, sessionId, webSocketManageService.getConnectedCount());
} else {
String principalName = sessionPrincipalNames.get(sessionId);
if (principalName != null) {
webSocketManageService.remove(principalName, sessionId);
log.debug("{} is disconnected (pre-auth). ID: {}. WebSocketSession[current count: {}]",
principalName, sessionId, webSocketManageService.getConnectedCount());
}
} }
} }
authenticatedSessions.remove(sessionId); authenticatedSessions.remove(sessionId);
sessionClaims.remove(sessionId); sessionClaims.remove(sessionId);
sessionPrincipalNames.remove(sessionId);
} }
@Override @Override
......
...@@ -63,9 +63,12 @@ public class WebSocketMessageServiceImpl implements IWebSocketMessageService { ...@@ -63,9 +63,12 @@ public class WebSocketMessageServiceImpl implements IWebSocketMessageService {
for (MyConcurrentWebSocketSession session : sessions) { for (MyConcurrentWebSocketSession session : sessions) {
if (!session.isOpen()) { if (!session.isOpen()) {
session.close(); try {
log.debug("This session is closed."); session.close();
return; } catch (Exception ignored) {
}
log.debug("Skipping closed session: {}", session.getId());
continue;
} }
session.sendMessage(data); session.sendMessage(data);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment