提交 7481fe3d 作者: 方治民

chore: 优化 WebSocket 相关配置和校验

上级 c33f9d96
...@@ -39,24 +39,26 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer { ...@@ -39,24 +39,26 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer {
final ClientOutboundChannelInterceptor clientOutboundChannelInterceptor; final ClientOutboundChannelInterceptor clientOutboundChannelInterceptor;
public static Integer stompPort; public static Integer stompPort;
public static boolean simpleMode; public static boolean stompMode;
public static String mode; public static String mode;
@PostConstruct @PostConstruct
public void init() { public void init() {
stompPort = Convert.toInt(SpringUtil.getProperty("spring.rabbitmq.stomp-port")); stompPort = Convert.toInt(SpringUtil.getProperty("spring.rabbitmq.stomp-port"));
simpleMode = Objects.isNull(stompPort); stompMode = Objects.nonNull(stompPort);
mode = simpleMode ? "Simple" : "STOMP"; mode = stompMode ? "STOMP" : "Simple";
} }
@Override @Override
public void registerStompEndpoints(StompEndpointRegistry registry) { public void registerStompEndpoints(StompEndpointRegistry registry) {
// SockJS 连接
registry registry
.addEndpoint("/stomp/sock-js") .addEndpoint("/stomp/sock-js")
.setAllowedOriginPatterns("*") .setAllowedOriginPatterns("*")
.addInterceptors(new HttpSessionHandshakeInterceptor()) .addInterceptors(new HttpSessionHandshakeInterceptor())
.withSockJS(); .withSockJS();
// 原生 WebSocket 连接
registry registry
.addEndpoint("/stomp/ws") .addEndpoint("/stomp/ws")
.setAllowedOriginPatterns("*") .setAllowedOriginPatterns("*")
...@@ -68,6 +70,7 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer { ...@@ -68,6 +70,7 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer {
@Override @Override
public void configureMessageBroker(MessageBrokerRegistry registry) { public void configureMessageBroker(MessageBrokerRegistry registry) {
// 启动前先删除掉可能存在的残留STOMP连接缓存数据 // 启动前先删除掉可能存在的残留STOMP连接缓存数据
// FIXME: 没有考虑多服务部署场景,仅单机模式
redis.del(RedisKey.STOMP_ONLINE_USERS); redis.del(RedisKey.STOMP_ONLINE_USERS);
log.info("WebSocket(Mode: {}) clear online user info cache of redis.", mode); log.info("WebSocket(Mode: {}) clear online user info cache of redis.", mode);
...@@ -75,14 +78,12 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer { ...@@ -75,14 +78,12 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer {
registry.setUserDestinationPrefix("/user"); registry.setUserDestinationPrefix("/user");
registry.setApplicationDestinationPrefixes("/app"); registry.setApplicationDestinationPrefixes("/app");
if (simpleMode) { String[] destinationPrefixes = { "/topic", "/queue" };
// 1. 使用内存方式处理消息 if (stompMode) {
registry.enableSimpleBroker("/topic", "/queue"); // 1. 使用 RabbitMQ 处理消息(需要安装 STOMP 插件)
} else {
// 2. 使用 RabbitMQ 处理消息(需要安装 STOMP 插件)
RabbitProperties rabbitProperties = SpringUtil.getBean(RabbitProperties.class); RabbitProperties rabbitProperties = SpringUtil.getBean(RabbitProperties.class);
registry registry
.enableStompBrokerRelay("/topic", "/queue") .enableStompBrokerRelay(destinationPrefixes)
.setRelayPort(stompPort) .setRelayPort(stompPort)
.setRelayHost(rabbitProperties.getHost()) .setRelayHost(rabbitProperties.getHost())
.setVirtualHost(rabbitProperties.getVirtualHost()) .setVirtualHost(rabbitProperties.getVirtualHost())
...@@ -90,6 +91,9 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer { ...@@ -90,6 +91,9 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer {
.setClientPasscode(rabbitProperties.getPassword()) .setClientPasscode(rabbitProperties.getPassword())
.setSystemLogin(rabbitProperties.getUsername()) .setSystemLogin(rabbitProperties.getUsername())
.setSystemPasscode(rabbitProperties.getPassword()); .setSystemPasscode(rabbitProperties.getPassword());
} else {
// 2. 使用内存方式处理消息
registry.enableSimpleBroker(destinationPrefixes);
} }
log.info("WebSocket(Mode: {}) init messageBroker success.", mode); log.info("WebSocket(Mode: {}) init messageBroker success.", mode);
......
...@@ -18,9 +18,9 @@ import org.springframework.stereotype.Component; ...@@ -18,9 +18,9 @@ import org.springframework.stereotype.Component;
public class AbstractMessageHandler { public class AbstractMessageHandler {
public SimpMessageHeaderAccessor getAccessor(@NonNull Message<?> message) { public SimpMessageHeaderAccessor getAccessor(@NonNull Message<?> message) {
Class<? extends SimpMessageHeaderAccessor> clazz = WebSocketStompConfig.simpleMode Class<? extends SimpMessageHeaderAccessor> clazz = WebSocketStompConfig.stompMode
? SimpMessageHeaderAccessor.class ? StompHeaderAccessor.class
: StompHeaderAccessor.class; : SimpMessageHeaderAccessor.class;
return MessageHeaderAccessor.getAccessor(message, clazz); return MessageHeaderAccessor.getAccessor(message, clazz);
} }
......
...@@ -19,6 +19,7 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor; ...@@ -19,6 +19,7 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.NativeMessageHeaderAccessor; import org.springframework.messaging.support.NativeMessageHeaderAccessor;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
/** /**
* ClientInboundChannelInterceptor * ClientInboundChannelInterceptor
...@@ -43,7 +44,7 @@ public class ClientInboundChannelInterceptor implements ChannelInterceptor { ...@@ -43,7 +44,7 @@ public class ClientInboundChannelInterceptor implements ChannelInterceptor {
@Override @Override
public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) { public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) {
SimpMessageHeaderAccessor accessor = handler.getAccessor(message); SimpMessageHeaderAccessor accessor = handler.getAccessor(message);
assert accessor != null; Assert.state(accessor != null, "No accessor");
if (handler.isConnect(accessor)) { if (handler.isConnect(accessor)) {
Object raw = message.getHeaders().get(NativeMessageHeaderAccessor.NATIVE_HEADERS); Object raw = message.getHeaders().get(NativeMessageHeaderAccessor.NATIVE_HEADERS);
...@@ -53,11 +54,16 @@ public class ClientInboundChannelInterceptor implements ChannelInterceptor { ...@@ -53,11 +54,16 @@ public class ClientInboundChannelInterceptor implements ChannelInterceptor {
Object tokens = ((Map<?, ?>) raw).get("token"); Object tokens = ((Map<?, ?>) raw).get("token");
if (tokens instanceof Collection<?>) { if (tokens instanceof Collection<?>) {
String token = Convert.toList(String.class, tokens).get(0); try {
User user = auths.getUserByToken(token); String token = Convert.toList(String.class, tokens).get(0);
principal.setToken(token); User user = auths.getUserByToken(token);
principal.setUser(user.getRealName()); principal.setToken(token);
principal.setType(StompPrincipal.Type.LOGIN_USER); principal.setUser(user.getRealName());
principal.setType(StompPrincipal.Type.LOGIN_USER);
} catch (Exception e) {
log.warn("WebSocket(Mode: {}) connect failed: {}", WebSocketStompConfig.mode, e.getMessage());
throw e;
}
} else { } else {
principal.setUser("Guest." + principal.getSession()); principal.setUser("Guest." + principal.getSession());
principal.setType(StompPrincipal.Type.GUEST_USER); principal.setType(StompPrincipal.Type.GUEST_USER);
......
...@@ -12,6 +12,7 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor; ...@@ -12,6 +12,7 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
/** /**
* ClientOutboundChannelInterceptor * ClientOutboundChannelInterceptor
...@@ -32,7 +33,7 @@ public class ClientOutboundChannelInterceptor implements ChannelInterceptor { ...@@ -32,7 +33,7 @@ public class ClientOutboundChannelInterceptor implements ChannelInterceptor {
@Override @Override
public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) { public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) {
SimpMessageHeaderAccessor accessor = handler.getAccessor(message); SimpMessageHeaderAccessor accessor = handler.getAccessor(message);
assert accessor != null; Assert.state(accessor != null, "No accessor");
if (handler.isConnected(accessor)) { if (handler.isConnected(accessor)) {
StompPrincipal principal = (StompPrincipal) accessor.getUser(); StompPrincipal principal = (StompPrincipal) accessor.getUser();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论