提交 781d88fe 作者: 方治民

test: stomp + simple 模式共存的适配(未完成)

上级 68ebdcde
...@@ -4,15 +4,14 @@ env: ...@@ -4,15 +4,14 @@ env:
prod: false prod: false
props: props:
username: admin username: admin
password: Hd)XZgtCa&NG~oe@ password: 123456
spring: spring:
datasource: datasource:
url: jdbc:h2:file:~/h2_basic;DB_CLOSE_ON_EXIT=FALSE url: jdbc:h2:file:~/h2_basic;DB_CLOSE_ON_EXIT=FALSE;;NON_KEYWORDS=VALUE
username: sa username: sa
password: 123456 password: 123456
jpa: jpa:
database-platform: org.hibernate.dialect.H2Dialect
show-sql: true show-sql: true
open-in-view: true open-in-view: true
hibernate: hibernate:
...@@ -26,6 +25,12 @@ spring: ...@@ -26,6 +25,12 @@ spring:
port: 6379 port: 6379
host: ${env.host} host: ${env.host}
password: ${env.props.password} password: ${env.props.password}
rabbitmq:
port: 5672
username: ${env.props.username}
password: ${env.props.password}
virtual-host: admin
# stomp-port: 61613
# knife4j # knife4j
knife4j: knife4j:
......
...@@ -16,8 +16,8 @@ spring: ...@@ -16,8 +16,8 @@ spring:
max-file-size: 1024MB max-file-size: 1024MB
max-request-size: 1048MB max-request-size: 1048MB
profiles: profiles:
include: auth, conf-patch, monitor include: auth, conf-patch #, monitor
active: dev-postgresql active: mock
# DEBUG # DEBUG
debug: false debug: true
/* (C) 2022 YiRing, Inc. */ /* (C) 2022 YiRing, Inc. */
package com.yiring.websocket.config; package com.yiring.websocket.config;
import cn.hutool.core.convert.Convert;
import cn.hutool.extra.spring.SpringUtil; import cn.hutool.extra.spring.SpringUtil;
import com.yiring.common.core.Redis; import com.yiring.common.core.Redis;
import com.yiring.websocket.constant.RedisKey; import com.yiring.websocket.constant.RedisKey;
import com.yiring.websocket.interceptor.ClientInboundChannelInterceptor; import com.yiring.websocket.interceptor.ClientInboundChannelInterceptor;
import com.yiring.websocket.interceptor.ClientOutboundChannelInterceptor; import com.yiring.websocket.interceptor.ClientOutboundChannelInterceptor;
import java.util.Objects;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.autoconfigure.amqp.RabbitProperties; import org.springframework.boot.autoconfigure.amqp.RabbitProperties;
...@@ -37,6 +35,17 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer { ...@@ -37,6 +35,17 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer {
final ClientInboundChannelInterceptor clientInboundChannelInterceptor; final ClientInboundChannelInterceptor clientInboundChannelInterceptor;
final ClientOutboundChannelInterceptor clientOutboundChannelInterceptor; final ClientOutboundChannelInterceptor clientOutboundChannelInterceptor;
public static Integer stompPort;
public static boolean simpleMode;
public static String mode;
// @PostConstruct
// public void init() {
//// stompPort = Convert.toInt(SpringUtil.getProperty("spring.rabbitmq.stomp-port"));
// simpleMode = Objects.isNull(stompPort);
// mode = simpleMode ? "Simple" : "STOMP";
// }
@Override @Override
public void registerStompEndpoints(StompEndpointRegistry registry) { public void registerStompEndpoints(StompEndpointRegistry registry) {
registry registry
...@@ -50,21 +59,20 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer { ...@@ -50,21 +59,20 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer {
.setAllowedOriginPatterns("*") .setAllowedOriginPatterns("*")
.addInterceptors(new HttpSessionHandshakeInterceptor()); .addInterceptors(new HttpSessionHandshakeInterceptor());
log.info("Init STOMP Endpoints Success."); log.info("WebSocket(Mode: {}) init endpoints success.", mode);
} }
@Override @Override
public void configureMessageBroker(MessageBrokerRegistry registry) { public void configureMessageBroker(MessageBrokerRegistry registry) {
// 启动前先删除掉可能存在的残留STOMP连接缓存数据 // 启动前先删除掉可能存在的残留STOMP连接缓存数据
redis.del(RedisKey.STOMP_ONLINE_USERS); redis.del(RedisKey.STOMP_ONLINE_USERS);
log.info("Clear STOMP online user info cache of redis."); log.info("WebSocket(Mode: {}) clear online user info cache of redis.", mode);
registry.setPreservePublishOrder(true); registry.setPreservePublishOrder(true);
registry.setUserDestinationPrefix("/user"); registry.setUserDestinationPrefix("/user");
registry.setApplicationDestinationPrefixes("/app"); registry.setApplicationDestinationPrefixes("/app");
String stompPort = SpringUtil.getProperty("spring.rabbitmq.stomp-port"); if (simpleMode) {
if (Objects.isNull(stompPort)) {
// 1. 使用内存方式处理消息 // 1. 使用内存方式处理消息
registry.enableSimpleBroker("/topic", "/queue"); registry.enableSimpleBroker("/topic", "/queue");
} else { } else {
...@@ -72,7 +80,7 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer { ...@@ -72,7 +80,7 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer {
RabbitProperties rabbitProperties = SpringUtil.getBean(RabbitProperties.class); RabbitProperties rabbitProperties = SpringUtil.getBean(RabbitProperties.class);
registry registry
.enableStompBrokerRelay("/topic", "/queue") .enableStompBrokerRelay("/topic", "/queue")
.setRelayPort(Convert.toInt(stompPort)) .setRelayPort(stompPort)
.setRelayHost(rabbitProperties.getHost()) .setRelayHost(rabbitProperties.getHost())
.setVirtualHost(rabbitProperties.getVirtualHost()) .setVirtualHost(rabbitProperties.getVirtualHost())
.setClientLogin(rabbitProperties.getUsername()) .setClientLogin(rabbitProperties.getUsername())
...@@ -81,7 +89,7 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer { ...@@ -81,7 +89,7 @@ public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer {
.setSystemPasscode(rabbitProperties.getPassword()); .setSystemPasscode(rabbitProperties.getPassword());
} }
log.info("Init RabbitMQ STOMP MessageBroker Success."); log.info("WebSocket(Mode: {}) init messageBroker success.", mode);
} }
@Override @Override
......
/* (C) 2024 YiRing, Inc. */
package com.yiring.websocket.interceptor;
import com.yiring.websocket.config.WebSocketStompConfig;
import lombok.NonNull;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.stereotype.Component;
/**
* @author Jim
*/
@Component
public class AbstractMessageHandler {
public SimpMessageHeaderAccessor getAccessor(@NonNull Message<?> message) {
Class<? extends SimpMessageHeaderAccessor> clazz = WebSocketStompConfig.simpleMode
? SimpMessageHeaderAccessor.class
: StompHeaderAccessor.class;
return MessageHeaderAccessor.getAccessor(message, clazz);
}
public boolean isConnect(@NonNull SimpMessageHeaderAccessor accessor) {
if (accessor instanceof StompHeaderAccessor) {
return StompCommand.CONNECT.equals(((StompHeaderAccessor) accessor).getCommand());
} else {
return SimpMessageType.CONNECT.equals(accessor.getMessageType());
}
}
public boolean isDisconnect(@NonNull SimpMessageHeaderAccessor accessor) {
if (accessor instanceof StompHeaderAccessor) {
return StompCommand.DISCONNECT.equals(((StompHeaderAccessor) accessor).getCommand());
} else {
return SimpMessageType.DISCONNECT.equals(accessor.getMessageType());
}
}
public Boolean isConnected(@NonNull SimpMessageHeaderAccessor accessor) {
if (accessor instanceof StompHeaderAccessor) {
return StompCommand.CONNECTED.equals(((StompHeaderAccessor) accessor).getCommand());
} else {
return SimpMessageType.CONNECT_ACK.equals(accessor.getMessageType());
}
}
}
...@@ -5,6 +5,7 @@ import cn.hutool.core.convert.Convert; ...@@ -5,6 +5,7 @@ import cn.hutool.core.convert.Convert;
import com.yiring.auth.domain.user.User; import com.yiring.auth.domain.user.User;
import com.yiring.auth.util.Auths; import com.yiring.auth.util.Auths;
import com.yiring.common.core.Redis; import com.yiring.common.core.Redis;
import com.yiring.websocket.config.WebSocketStompConfig;
import com.yiring.websocket.constant.RedisKey; import com.yiring.websocket.constant.RedisKey;
import com.yiring.websocket.domain.StompPrincipal; import com.yiring.websocket.domain.StompPrincipal;
import java.util.Collection; import java.util.Collection;
...@@ -15,10 +16,7 @@ import lombok.extern.slf4j.Slf4j; ...@@ -15,10 +16,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.support.NativeMessageHeaderAccessor; import org.springframework.messaging.support.NativeMessageHeaderAccessor;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
...@@ -38,14 +36,16 @@ public class ClientInboundChannelInterceptor implements ChannelInterceptor { ...@@ -38,14 +36,16 @@ public class ClientInboundChannelInterceptor implements ChannelInterceptor {
final Redis redis; final Redis redis;
final Auths auths; final Auths auths;
final AbstractMessageHandler handler;
private final Object lock = new Object(); private final Object lock = new Object();
@Override @Override
public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) { public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); SimpMessageHeaderAccessor accessor = handler.getAccessor(message);
assert accessor != null; assert accessor != null;
if (StompCommand.CONNECT.equals(accessor.getCommand())) {
if (handler.isConnect(accessor)) {
Object raw = message.getHeaders().get(NativeMessageHeaderAccessor.NATIVE_HEADERS); Object raw = message.getHeaders().get(NativeMessageHeaderAccessor.NATIVE_HEADERS);
if (raw instanceof Map) { if (raw instanceof Map) {
StompPrincipal principal = new StompPrincipal(); StompPrincipal principal = new StompPrincipal();
...@@ -67,7 +67,8 @@ public class ClientInboundChannelInterceptor implements ChannelInterceptor { ...@@ -67,7 +67,8 @@ public class ClientInboundChannelInterceptor implements ChannelInterceptor {
synchronized (lock) { synchronized (lock) {
redis.hset(RedisKey.STOMP_ONLINE_USERS, principal.getSession(), principal); redis.hset(RedisKey.STOMP_ONLINE_USERS, principal.getSession(), principal);
log.info( log.info(
"STOMP Online Users: {} (incr: +1, user: {}, session: {}, token: {})", "WebSocket(Mode: {}) Online Users: {} (incr: +1, user: {}, session: {}, token: {})",
WebSocketStompConfig.mode,
redis.hsize(RedisKey.STOMP_ONLINE_USERS), redis.hsize(RedisKey.STOMP_ONLINE_USERS),
principal.getUser(), principal.getUser(),
principal.getSession(), principal.getSession(),
...@@ -75,13 +76,14 @@ public class ClientInboundChannelInterceptor implements ChannelInterceptor { ...@@ -75,13 +76,14 @@ public class ClientInboundChannelInterceptor implements ChannelInterceptor {
); );
} }
} }
} else if (StompCommand.DISCONNECT.equals(accessor.getCommand())) { } else if (handler.isDisconnect(accessor)) {
StompPrincipal principal = (StompPrincipal) accessor.getUser(); StompPrincipal principal = (StompPrincipal) accessor.getUser();
if (principal != null && !message.getHeaders().containsKey(SimpMessageHeaderAccessor.HEART_BEAT_HEADER)) { if (principal != null && !message.getHeaders().containsKey(SimpMessageHeaderAccessor.HEART_BEAT_HEADER)) {
synchronized (lock) { synchronized (lock) {
redis.hdel(RedisKey.STOMP_ONLINE_USERS, principal.getSession()); redis.hdel(RedisKey.STOMP_ONLINE_USERS, principal.getSession());
log.info( log.info(
"STOMP Online Users: {} (incr: -1, user: {}, session: {}, token: {})", "WebSocket(Mode: {}) Online Users: {} (incr: -1, user: {}, session: {}, token: {})",
WebSocketStompConfig.mode,
redis.hsize(RedisKey.STOMP_ONLINE_USERS), redis.hsize(RedisKey.STOMP_ONLINE_USERS),
principal.getUser(), principal.getUser(),
principal.getSession(), principal.getSession(),
......
...@@ -4,14 +4,13 @@ package com.yiring.websocket.interceptor; ...@@ -4,14 +4,13 @@ package com.yiring.websocket.interceptor;
import com.alibaba.fastjson2.JSON; import com.alibaba.fastjson2.JSON;
import com.yiring.websocket.domain.StompPrincipal; import com.yiring.websocket.domain.StompPrincipal;
import lombok.NonNull; import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
/** /**
...@@ -25,13 +24,17 @@ import org.springframework.stereotype.Component; ...@@ -25,13 +24,17 @@ import org.springframework.stereotype.Component;
@Slf4j @Slf4j
@Component @Component
@RequiredArgsConstructor
public class ClientOutboundChannelInterceptor implements ChannelInterceptor { public class ClientOutboundChannelInterceptor implements ChannelInterceptor {
final AbstractMessageHandler handler;
@Override @Override
public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) { public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); SimpMessageHeaderAccessor accessor = handler.getAccessor(message);
assert accessor != null; assert accessor != null;
if (StompCommand.CONNECTED.equals(accessor.getCommand())) {
if (handler.isConnected(accessor)) {
StompPrincipal principal = (StompPrincipal) accessor.getUser(); StompPrincipal principal = (StompPrincipal) accessor.getUser();
return MessageBuilder.createMessage(JSON.toJSONBytes(principal), message.getHeaders()); return MessageBuilder.createMessage(JSON.toJSONBytes(principal), message.getHeaders());
} }
......
...@@ -2,17 +2,18 @@ ...@@ -2,17 +2,18 @@
package com.yiring.websocket.registry; package com.yiring.websocket.registry;
import com.yiring.websocket.domain.StompPrincipal; import com.yiring.websocket.domain.StompPrincipal;
import com.yiring.websocket.interceptor.AbstractMessageHandler;
import java.security.Principal; import java.security.Principal;
import java.util.HashSet; import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull; import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEvent;
import org.springframework.context.event.SmartApplicationListener; import org.springframework.context.event.SmartApplicationListener;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.socket.messaging.AbstractSubProtocolEvent; import org.springframework.web.socket.messaging.AbstractSubProtocolEvent;
...@@ -28,8 +29,11 @@ import org.springframework.web.socket.messaging.SessionDisconnectEvent; ...@@ -28,8 +29,11 @@ import org.springframework.web.socket.messaging.SessionDisconnectEvent;
*/ */
@Component @Component
@RequiredArgsConstructor
public class CustomStompUserRegistry implements StompUserRegistry, SmartApplicationListener { public class CustomStompUserRegistry implements StompUserRegistry, SmartApplicationListener {
final AbstractMessageHandler handler;
/** /**
* sessionId, Principal * sessionId, Principal
*/ */
...@@ -47,23 +51,22 @@ public class CustomStompUserRegistry implements StompUserRegistry, SmartApplicat ...@@ -47,23 +51,22 @@ public class CustomStompUserRegistry implements StompUserRegistry, SmartApplicat
AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event; AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event;
Message<?> message = subProtocolEvent.getMessage(); Message<?> message = subProtocolEvent.getMessage();
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor( if (event instanceof SessionConnectedEvent || event instanceof SessionDisconnectEvent) {
message, SimpMessageHeaderAccessor accessor = handler.getAccessor(message);
SimpMessageHeaderAccessor.class Assert.state(accessor != null, "No Accessor");
);
Assert.state(accessor != null, "No SimpMessageHeaderAccessor");
String sessionId = accessor.getSessionId(); String sessionId = accessor.getSessionId();
Assert.state(sessionId != null, "No session id"); Assert.state(sessionId != null, "No session id");
if (event instanceof SessionConnectedEvent) { if (event instanceof SessionConnectedEvent) {
Principal user = subProtocolEvent.getUser(); Principal user = subProtocolEvent.getUser();
synchronized (lock) { synchronized (lock) {
this.users.put(sessionId, (StompPrincipal) user); this.users.put(sessionId, (StompPrincipal) user);
} }
} else if (event instanceof SessionDisconnectEvent) { } else {
synchronized (lock) { synchronized (lock) {
this.users.remove(sessionId); this.users.remove(sessionId);
}
} }
} }
} }
......
...@@ -87,8 +87,11 @@ public class StompReceiver { ...@@ -87,8 +87,11 @@ public class StompReceiver {
public void test(StompHeaderAccessor accessor, String message) { public void test(StompHeaderAccessor accessor, String message) {
log.info("收到来自 STOMP Client `/app/ping` 消息:{}", message); log.info("收到来自 STOMP Client `/app/ping` 消息:{}", message);
Set<SimpUser> users = simpUserRegistry.getUsers(); Set<SimpUser> simpUsers = simpUserRegistry.getUsers();
log.info("{}", users); log.info("SimpUsers: {}", simpUsers);
Set<StompPrincipal> stompPrincipals = stompUserRegistry.getUsers();
log.info("StompPrincipals: {}", stompPrincipals);
JSONObject body = new JSONObject(); JSONObject body = new JSONObject();
body.put("message", "pong"); body.put("message", "pong");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论