/* (C) 2022 YiRing, Inc. */
package com.yiring.websocket.interceptor;

import cn.hutool.core.convert.Convert;
import com.yiring.auth.domain.user.User;
import com.yiring.auth.util.Auths;
import com.yiring.common.core.Redis;
import com.yiring.websocket.config.WebSocketStompConfig;
import com.yiring.websocket.constant.RedisKey;
import com.yiring.websocket.domain.StompPrincipal;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.NativeMessageHeaderAccessor;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;

/**
 * ClientInboundChannelInterceptor
 * Interceptor for messages received from clients (inbound channel).
 *
 * <p>On a CONNECT message it builds a {@link StompPrincipal} for the session, attaches it to the
 * message accessor and records it in the Redis hash {@link RedisKey#STOMP_ONLINE_USERS} keyed by
 * session id. A connection carrying a {@code token} native header is resolved to a logged-in
 * user; a connection without one becomes a guest. On a DISCONNECT message the session's entry is
 * removed from that hash.
 *
 * @author ifzm
 * @version 0.1
 * 2019/9/28 20:58
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class ClientInboundChannelInterceptor implements ChannelInterceptor {

    /** Number of leading token characters kept when a token is written to the log. */
    private static final int TOKEN_LOG_PREFIX = 6;

    final Redis redis;
    final Auths auths;
    final AbstractMessageHandler handler;

    /**
     * JVM-local lock that pairs each Redis write with the size read used in the log line.
     * It does not coordinate with other application instances sharing the same Redis.
     */
    private final Object lock = new Object();

    /**
     * Tracks online sessions on CONNECT / DISCONNECT and passes every message through unchanged.
     *
     * @param message the inbound message
     * @param channel the channel the message is being sent to (unused)
     * @return the same {@code message}
     * @throws IllegalStateException if no header accessor can be obtained for the message, or a
     *     token resolves to no user
     * @throws IllegalArgumentException if the {@code token} header is present but empty
     */
    @Override
    public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) {
        SimpMessageHeaderAccessor accessor = handler.getAccessor(message);
        Assert.state(accessor != null, "No accessor");

        if (handler.isConnect(accessor)) {
            onConnect(message, accessor);
        } else if (handler.isDisconnect(accessor)) {
            onDisconnect(message, accessor);
        }

        return message;
    }

    /** Builds the session principal, attaches it to the accessor and registers it in Redis. */
    private void onConnect(Message<?> message, SimpMessageHeaderAccessor accessor) {
        Object raw = message.getHeaders().get(NativeMessageHeaderAccessor.NATIVE_HEADERS);
        if (!(raw instanceof Map)) {
            // No native STOMP headers: nothing is registered for this session.
            return;
        }

        StompPrincipal principal = new StompPrincipal();
        principal.setSession(accessor.getSessionId());

        Object tokens = ((Map<?, ?>) raw).get("token");
        if (tokens instanceof Collection<?>) {
            authenticate(principal, (Collection<?>) tokens);
        } else {
            principal.setUser("Guest." + principal.getSession());
            principal.setType(StompPrincipal.Type.GUEST_USER);
        }

        accessor.setUser(principal);
        synchronized (lock) {
            redis.hset(RedisKey.STOMP_ONLINE_USERS, principal.getSession(), principal);
            log.info(
                "WebSocket(Mode: {}) Online Users: {} (incr: +1, user: {}, session: {}, token: {})",
                WebSocketStompConfig.mode,
                redis.hsize(RedisKey.STOMP_ONLINE_USERS),
                principal.getUser(),
                principal.getSession(),
                maskToken(principal.getToken())
            );
        }
    }

    /**
     * Resolves the first value of the {@code token} header to a user and fills in a logged-in
     * principal. Any failure is logged and rethrown, which rejects the CONNECT.
     */
    private void authenticate(StompPrincipal principal, Collection<?> tokens) {
        try {
            List<String> values = Convert.toList(String.class, tokens);
            if (values == null || values.isEmpty()) {
                throw new IllegalArgumentException("Empty 'token' header");
            }
            String token = values.get(0);
            User user = auths.getUserByToken(token);
            if (user == null) {
                throw new IllegalStateException("No user found for the supplied token");
            }
            principal.setToken(token);
            principal.setUser(user.getRealName());
            principal.setType(StompPrincipal.Type.LOGIN_USER);
        } catch (Exception e) {
            log.warn("WebSocket(Mode: {}) connect failed: {}", WebSocketStompConfig.mode, e.getMessage());
            throw e;
        }
    }

    /** Removes the disconnecting session from the Redis online-user hash. */
    private void onDisconnect(Message<?> message, SimpMessageHeaderAccessor accessor) {
        Object user = accessor.getUser();
        // Only sessions whose principal was set by onConnect are tracked in Redis.
        if (!(user instanceof StompPrincipal)) {
            return;
        }
        // NOTE(review): kept from the original; messages carrying a heartbeat header are skipped
        // here — confirm the intent of this check.
        if (message.getHeaders().containsKey(SimpMessageHeaderAccessor.HEART_BEAT_HEADER)) {
            return;
        }

        StompPrincipal principal = (StompPrincipal) user;
        synchronized (lock) {
            redis.hdel(RedisKey.STOMP_ONLINE_USERS, principal.getSession());
            log.info(
                "WebSocket(Mode: {}) Online Users: {} (incr: -1, user: {}, session: {}, token: {})",
                WebSocketStompConfig.mode,
                redis.hsize(RedisKey.STOMP_ONLINE_USERS),
                principal.getUser(),
                principal.getSession(),
                maskToken(principal.getToken())
            );
        }
    }

    /**
     * Shortens a token for logging so the full credential is never written to the log.
     *
     * @param token the token, possibly {@code null} (guest sessions have none)
     * @return {@code null} for {@code null}, {@code "***"} for short tokens, otherwise the first
     *     {@value #TOKEN_LOG_PREFIX} characters followed by {@code "***"}
     */
    private static String maskToken(String token) {
        if (token == null) {
            return null;
        }
        if (token.length() <= TOKEN_LOG_PREFIX) {
            return "***";
        }
        return token.substring(0, TOKEN_LOG_PREFIX) + "***";
    }
}
