/* (C) 2022 YiRing, Inc. */
package com.yiring.websocket.config;

import cn.hutool.extra.spring.SpringUtil;
import com.yiring.common.core.Redis;
import com.yiring.websocket.constant.RedisKey;
import com.yiring.websocket.interceptor.ClientInboundChannelInterceptor;
import com.yiring.websocket.interceptor.ClientOutboundChannelInterceptor;
import jakarta.annotation.PostConstruct;
import java.util.Objects;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.amqp.RabbitProperties;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;

/**
 * WebSocketStompConfig
 *
 * <p>Registers the STOMP endpoints {@code /stomp/sock-js} (SockJS) and {@code /stomp/ws} (native
 * WebSocket). It also selects the message broker for the {@code /topic} and {@code /queue}
 * destinations:
 * <ul>
 *   <li><b>STOMP mode</b>: used when {@code spring.rabbitmq.stomp-port} is set. Messages are relayed
 *       to RabbitMQ, which must have the STOMP plugin installed.</li>
 *   <li><b>Simple mode</b>: used otherwise. Messages are handled by Spring's in-memory simple broker.</li>
 * </ul>
 *
 * @author ifzm
 * @version 0.1
 * 2019/9/25 20:12
 */

@Slf4j
@Configuration
@EnableWebSocketMessageBroker
@RequiredArgsConstructor
public class WebSocketStompConfig implements WebSocketMessageBrokerConfigurer {

    final Redis redis;
    final ClientInboundChannelInterceptor clientInboundChannelInterceptor;
    final ClientOutboundChannelInterceptor clientOutboundChannelInterceptor;

    /**
     * Port of the RabbitMQ STOMP plugin. The empty default is expected to resolve to {@code null},
     * and a {@code null} value selects Simple mode in {@link #init()}.
     */
    @Value("${spring.rabbitmq.stomp-port:}")
    public Integer stompPort;

    /**
     * Origin patterns accepted by both STOMP endpoints (comma-separated in configuration).
     *
     * <p>The default is {@code *}, which accepts any origin. Because the handshake copies HTTP
     * session attributes into the WebSocket session, production deployments should narrow this to
     * trusted origins.
     */
    @Value("${websocket.allowed-origin-patterns:*}")
    String[] allowedOriginPatterns;

    /** {@code true} when the RabbitMQ STOMP broker relay is used. Assigned once in {@link #init()}. */
    public static boolean stompMode;

    /** Mode label used in log output: {@code "STOMP"} or {@code "Simple"}. Assigned in {@link #init()}. */
    public static String mode;

    /** Derives the broker mode from whether a STOMP relay port has been configured. */
    @PostConstruct
    public void init() {
        stompMode = Objects.nonNull(stompPort);
        mode = stompMode ? "STOMP" : "Simple";
    }

    /**
     * Registers the SockJS and native WebSocket STOMP endpoints.
     *
     * <p>Both endpoints accept the configured origin patterns and install a
     * {@link HttpSessionHandshakeInterceptor}.
     */
    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        // SockJS endpoint
        registry
            .addEndpoint("/stomp/sock-js")
            .setAllowedOriginPatterns(allowedOriginPatterns)
            .addInterceptors(new HttpSessionHandshakeInterceptor())
            .withSockJS();

        // Native WebSocket endpoint
        registry
            .addEndpoint("/stomp/ws")
            .setAllowedOriginPatterns(allowedOriginPatterns)
            .addInterceptors(new HttpSessionHandshakeInterceptor());

        log.info("WebSocket(Mode: {}) init endpoints success.", mode);
    }

    /**
     * Configures destination prefixes and chooses the broker.
     *
     * <p>In STOMP mode, the relay connection settings come from {@link RabbitProperties}. The
     * {@code determine*} accessors are used so that {@code spring.rabbitmq.addresses} is honoured
     * (via its first address). Without that property, they fall back to the discrete
     * host/virtual-host/username/password properties.
     */
    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        // Before startup, delete any leftover STOMP online-user cache from a previous run.
        // FIXME: assumes a single instance; with multiple service instances this wipes the
        //  online users tracked by the other nodes.
        redis.del(RedisKey.STOMP_ONLINE_USERS);
        log.info("WebSocket(Mode: {}) clear online user info cache of redis.", mode);

        registry.setPreservePublishOrder(true);
        registry.setUserDestinationPrefix("/user");
        registry.setApplicationDestinationPrefixes("/app");

        String[] destinationPrefixes = { "/topic", "/queue" };
        if (stompMode) {
            // 1. Relay messages through RabbitMQ (requires the STOMP plugin)
            RabbitProperties rabbitProperties = SpringUtil.getBean(RabbitProperties.class);
            String username = rabbitProperties.determineUsername();
            String password = rabbitProperties.determinePassword();
            registry
                .enableStompBrokerRelay(destinationPrefixes)
                .setRelayPort(stompPort)
                .setRelayHost(rabbitProperties.determineHost())
                .setVirtualHost(rabbitProperties.determineVirtualHost())
                .setClientLogin(username)
                .setClientPasscode(password)
                .setSystemLogin(username)
                .setSystemPasscode(password);
        } else {
            // 2. Handle messages with the in-memory simple broker
            registry.enableSimpleBroker(destinationPrefixes);
        }

        log.info("WebSocket(Mode: {}) init messageBroker success.", mode);
    }

    /** Adds the project's inbound interceptor to the client inbound channel. */
    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        registration.interceptors(clientInboundChannelInterceptor);
    }

    /** Adds the project's outbound interceptor to the client outbound channel. */
    @Override
    public void configureClientOutboundChannel(ChannelRegistration registration) {
        registration.interceptors(clientOutboundChannelInterceptor);
    }
}
