/* (C) 2022 YiRing, Inc. */
package com.yiring.websocket.registry;

import com.yiring.websocket.domain.StompPrincipal;
import com.yiring.websocket.interceptor.AbstractMessageHandler;
import java.security.Principal;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.socket.messaging.AbstractSubProtocolEvent;
import org.springframework.web.socket.messaging.SessionConnectedEvent;
import org.springframework.web.socket.messaging.SessionDisconnectEvent;

/**
 * Custom STOMP online-user registry: tracks the principal of every connected STOMP session and
 * exposes lookup and update operations on them.
 *
 * <p>Entries are added on {@link SessionConnectedEvent} and removed on {@link
 * SessionDisconnectEvent}. The backing store is a {@link ConcurrentHashMap} and every read or
 * mutation below is a single atomic map operation, so no additional locking is needed.
 *
 * @author ifzm
 * @version 0.1
 * 2019/10/10 21:19
 */
@Component
@RequiredArgsConstructor
public class CustomStompUserRegistry implements StompUserRegistry, SmartApplicationListener {

    /** Used to obtain the STOMP header accessor (and thus the session id) of event messages. */
    final AbstractMessageHandler handler;

    /**
     * Registered sessions: sessionId -> principal of that session.
     */
    private final Map<String, StompPrincipal> users = new ConcurrentHashMap<>();

    /**
     * Subscribes this listener to every STOMP sub-protocol event; {@link #onApplicationEvent}
     * then acts only on connect and disconnect events.
     *
     * @param eventType the published event type
     * @return {@code true} if {@code eventType} is an {@link AbstractSubProtocolEvent}
     */
    @Override
    public boolean supportsEventType(@NonNull Class<? extends ApplicationEvent> eventType) {
        return AbstractSubProtocolEvent.class.isAssignableFrom(eventType);
    }

    /**
     * Registers the session's principal on connect and unregisters the session on disconnect.
     *
     * <p>A connection whose principal is {@code null} or not a {@link StompPrincipal} is not
     * registered: {@link ConcurrentHashMap} rejects {@code null} values and casting a foreign
     * principal would throw {@link ClassCastException}.
     *
     * @param event the published event; other event types are ignored
     * @throws IllegalStateException if a connect/disconnect message has no header accessor or no
     *     session id
     */
    @Override
    public void onApplicationEvent(@NonNull ApplicationEvent event) {
        boolean connected = event instanceof SessionConnectedEvent;
        if (!connected && !(event instanceof SessionDisconnectEvent)) {
            return;
        }

        AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event;
        String sessionId = resolveSessionId(subProtocolEvent.getMessage());

        if (connected) {
            Principal user = subProtocolEvent.getUser();
            // Skip anonymous or non-STOMP principals instead of failing on null / bad cast.
            if (user instanceof StompPrincipal) {
                this.users.put(sessionId, (StompPrincipal) user);
            }
        } else {
            this.users.remove(sessionId);
        }
    }

    /**
     * Extracts the session id from a STOMP message, failing fast when it cannot be determined.
     *
     * @param message the message carried by the sub-protocol event
     * @return the non-null session id
     * @throws IllegalStateException if there is no header accessor or no session id
     */
    private String resolveSessionId(Message<?> message) {
        SimpMessageHeaderAccessor accessor = handler.getAccessor(message);
        Assert.state(accessor != null, "No accessor");

        String sessionId = accessor.getSessionId();
        Assert.state(sessionId != null, "No session id");
        return sessionId;
    }

    /**
     * Returns a snapshot of the registered principals.
     *
     * <p>The result is a new mutable {@link HashSet}; changing it does not affect the registry.
     * Duplicate principals collapse according to {@code StompPrincipal}'s {@code equals}.
     *
     * @return a copy of the registered principals
     */
    @Override
    public Set<StompPrincipal> getUsers() {
        return new HashSet<>(this.users.values());
    }

    /**
     * Returns the number of registered sessions.
     *
     * <p>This counts sessions, so it may exceed {@code getUsers().size()} if several sessions share
     * an equal principal.
     *
     * @return the number of registered sessions
     */
    @Override
    public int getUserCount() {
        return this.users.size();
    }

    /**
     * Looks up the principal registered for a session.
     *
     * @param sessionId the STOMP session id
     * @return the registered principal, or {@code null} if the session is not registered
     * @throws NullPointerException if {@code sessionId} is {@code null}
     */
    @Override
    public StompPrincipal getUser(String sessionId) {
        return this.users.get(sessionId);
    }

    /**
     * Replaces the principal of an already-registered session; does nothing when the session is
     * not registered. The presence check and the write are one atomic
     * {@link ConcurrentHashMap#replace} call.
     *
     * @param sessionId the STOMP session id
     * @param principal the new principal for that session
     * @throws NullPointerException if {@code sessionId} or {@code principal} is {@code null}
     */
    @Override
    public void updateUser(String sessionId, StompPrincipal principal) {
        this.users.replace(sessionId, principal);
    }
}
