package org.springframework.web.socket.messaging;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.aspectj.weaver.Constants;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompConversionException;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.SessionLimitExceededException;
import org.springframework.web.socket.handler.WebSocketSessionDecorator;
import org.springframework.web.socket.sockjs.transport.SockJsSession;

/* loaded from: input_file:BOOT-INF/lib/spring-websocket-4.0.5.RELEASE.jar:org/springframework/web/socket/messaging/StompSubProtocolHandler.class */
/**
 * A {@link SubProtocolHandler} that translates between WebSocket text messages and
 * STOMP frames represented as {@link Message} instances.
 *
 * <p>Inbound WebSocket messages are decoded through a per-session
 * {@link BufferingStompDecoder} and forwarded to a {@link MessageChannel}; outbound
 * messages are encoded with a {@link StompEncoder} and written to the
 * {@link WebSocketSession}. When an {@link ApplicationEventPublisher} has been set,
 * session connect, connected and disconnect events are published.
 */
public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {

    /**
     * Lower bound applied to a session's text message size limit in
     * {@link #afterSessionStarted}. The value equals 16 * 1024 + 256.
     */
    public static final int MINIMUM_WEBSOCKET_MESSAGE_SIZE = 16640;

    /** Native header set on CONNECTED frames to carry the name of the session's principal. */
    public static final String CONNECTED_USER_HEADER = "user-name";

    private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class);

    /** Header of a CONNECT_ACK message under which the original CONNECT message is looked up. */
    private static final String CONNECT_MESSAGE_HEADER = "simpConnectMessage";

    /** Native header whose value, when present, replaces the destination of an outbound MESSAGE. */
    private static final String SUBSCRIBE_DESTINATION_HEADER = "subscribeDestination";

    private static final String STOMP_VERSION_1_2 = "1.2";

    private static final String STOMP_VERSION_1_1 = "1.1";

    /** Buffer size limit handed to each new {@link BufferingStompDecoder}; defaults to 64 * 1024. */
    private int messageSizeLimit = 65536;

    /** One decoder per WebSocket session id, added on session start and removed on session end. */
    private final Map<String, BufferingStompDecoder> decoders =
            new ConcurrentHashMap<String, BufferingStompDecoder>();

    private final StompEncoder stompEncoder = new StompEncoder();

    private UserSessionRegistry userSessionRegistry;

    private ApplicationEventPublisher eventPublisher;

    /**
     * Set the buffer size limit passed to decoders created for sessions that start
     * after this call.
     */
    public void setMessageSizeLimit(int i) {
        this.messageSizeLimit = i;
    }

    /**
     * Return the configured message size limit.
     */
    public int getMessageSizeLimit() {
        return this.messageSizeLimit;
    }

    /**
     * Set the registry in which session ids are registered per user name once a
     * session is connected. May be left unset, in which case nothing is registered.
     */
    public void setUserSessionRegistry(UserSessionRegistry userSessionRegistry) {
        this.userSessionRegistry = userSessionRegistry;
    }

    /**
     * Return the configured user session registry, or {@code null} if none was set.
     */
    public UserSessionRegistry getUserSessionRegistry() {
        return this.userSessionRegistry;
    }

    /**
     * Return the sub-protocol names handled by this class.
     */
    @Override
    public List<String> getSupportedProtocols() {
        return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
    }

    @Override
    public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
        this.eventPublisher = applicationEventPublisher;
    }

    /**
     * Decode the content of a WebSocket text message into STOMP frames and send each
     * frame to the given channel.
     *
     * <p>Any failure is logged and reported back to the client through
     * {@link #sendErrorMessage}; nothing is propagated to the caller. A failure on
     * one frame does not prevent the remaining frames from being processed.
     *
     * @param webSocketSession the session the message was received on
     * @param webSocketMessage the received message; must be a {@link TextMessage}
     * @param messageChannel the channel decoded frames are sent to
     */
    @Override
    public void handleMessageFromClient(WebSocketSession webSocketSession, WebSocketMessage<?> webSocketMessage, MessageChannel messageChannel) {
        try {
            Assert.isInstanceOf(TextMessage.class, webSocketMessage);
            ByteBuffer content = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());

            BufferingStompDecoder decoder = this.decoders.get(webSocketSession.getId());
            if (decoder == null) {
                throw new IllegalStateException("No decoder for session id '" + webSocketSession.getId() + "'");
            }

            List<Message<byte[]>> frames = decoder.decode(content);
            if (frames.isEmpty()) {
                // Nothing complete yet: the decoder keeps the partial content buffered.
                if (logger.isDebugEnabled()) {
                    logger.debug("Incomplete STOMP frame content received,buffered=" + decoder.getBufferSize()
                            + ", buffer size limit=" + decoder.getBufferSizeLimit());
                }
                return;
            }

            for (Message<byte[]> frame : frames) {
                forwardClientFrame(webSocketSession, frame, messageChannel);
            }
        } catch (Throwable ex) {
            logger.error("Failed to parse WebSocket message to STOMP frame(s)", ex);
            sendErrorMessage(webSocketSession, ex);
        }
    }

    /**
     * Enrich a single decoded frame with session id, session attributes and user,
     * publish a {@link SessionConnectEvent} for CONNECT frames, and send the result
     * to the channel. Failures are logged and reported to the client.
     */
    private void forwardClientFrame(WebSocketSession session, Message<byte[]> frame, MessageChannel channel) {
        try {
            StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
            if (logger.isTraceEnabled()) {
                if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) {
                    logger.trace("Received heartbeat from client session=" + session.getId());
                } else {
                    logger.trace("Received message from client session=" + session.getId());
                }
            }

            headers.setSessionId(session.getId());
            headers.setSessionAttributes(session.getAttributes());
            headers.setUser(session.getPrincipal());

            Message<byte[]> enriched = MessageBuilder.withPayload(frame.getPayload()).setHeaders(headers).build();
            if (this.eventPublisher != null && StompCommand.CONNECT.equals(headers.getCommand())) {
                publishEvent(new SessionConnectEvent(this, enriched));
            }
            channel.send(enriched);
        } catch (Throwable ex) {
            logger.error("Terminating STOMP session due to failure to send message", ex);
            sendErrorMessage(session, ex);
        }
    }

    /**
     * Publish the event, logging (rather than propagating) anything a listener throws.
     * Callers must have checked that an event publisher is set.
     */
    private void publishEvent(ApplicationEvent applicationEvent) {
        try {
            this.eventPublisher.publishEvent(applicationEvent);
        } catch (Throwable ex) {
            logger.error("Error while publishing " + applicationEvent, ex);
        }
    }

    /**
     * Send a STOMP ERROR frame with an empty body whose message header is the
     * message of the given throwable.
     *
     * <p>This is best-effort: a failure to deliver the frame is logged at debug
     * level and never propagated.
     *
     * @param webSocketSession the session to notify
     * @param th the failure being reported
     */
    protected void sendErrorMessage(WebSocketSession webSocketSession, Throwable th) {
        StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
        headers.setMessage(th.getMessage());
        Message<byte[]> errorFrame = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
        try {
            webSocketSession.sendMessage(new TextMessage(this.stompEncoder.encode(errorFrame)));
        } catch (Throwable sendFailure) {
            // Deliberately not rethrown: we are already on an error path.
            if (logger.isDebugEnabled()) {
                logger.debug("Failed to send STOMP ERROR frame to session " + webSocketSession.getId(), sendFailure);
            }
        }
    }

    /**
     * Encode a message as a STOMP frame and write it to the WebSocket session.
     *
     * <p>A CONNECT_ACK message is turned into a CONNECTED frame; a MESSAGE without a
     * subscription id, or any message whose payload is not a {@code byte[]}, is
     * logged and dropped. After an ERROR frame has been handled (successfully or
     * not) the session is closed with {@link CloseStatus#PROTOCOL_ERROR}.
     *
     * @param webSocketSession the session to write to
     * @param message the message to send
     * @throws SessionLimitExceededException propagated unchanged when raised while sending
     */
    @Override
    public void handleMessageToClient(WebSocketSession webSocketSession, Message<?> message) {
        StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);

        if (headers.getMessageType() == SimpMessageType.CONNECT_ACK) {
            StompHeaderAccessor connected = StompHeaderAccessor.create(StompCommand.CONNECTED);
            connected.setVersion(getVersion(headers));
            connected.setHeartbeat(0L, 0L);
            headers = connected;
        } else if (SimpMessageType.MESSAGE.equals(headers.getMessageType())) {
            headers.updateStompCommandAsServerMessage();
        }

        if (headers.getCommand() == StompCommand.CONNECTED) {
            afterStompSessionConnected(headers, webSocketSession);
        }

        if (StompCommand.MESSAGE.equals(headers.getCommand())) {
            if (headers.getSubscriptionId() == null) {
                logger.error("Ignoring message, no subscriptionId header: " + message);
                return;
            }
            String subscribeDestination = headers.getFirstNativeHeader(SUBSCRIBE_DESTINATION_HEADER);
            if (subscribeDestination != null) {
                headers.setDestination(subscribeDestination);
            }
        }

        if (!(message.getPayload() instanceof byte[])) {
            logger.error("Ignoring message, expected byte[] content: " + message);
            return;
        }

        try {
            Message<byte[]> frame = MessageBuilder.withPayload((byte[]) message.getPayload()).setHeaders(headers).build();
            if (this.eventPublisher != null && StompCommand.CONNECTED.equals(headers.getCommand())) {
                publishEvent(new SessionConnectedEvent(this, frame));
            }
            webSocketSession.sendMessage(new TextMessage(this.stompEncoder.encode(frame)));
        } catch (SessionLimitExceededException ex) {
            // The session is over its limit; do not try to write an ERROR frame to it.
            throw ex;
        } catch (Throwable ex) {
            if (logger.isDebugEnabled()) {
                logger.debug("Failed to send message to session " + webSocketSession.getId(), ex);
            }
            sendErrorMessage(webSocketSession, ex);
        } finally {
            if (StompCommand.ERROR.equals(headers.getCommand())) {
                try {
                    webSocketSession.close(CloseStatus.PROTOCOL_ERROR);
                } catch (IOException closeFailure) {
                    // Nothing more can be done for this session; record it and move on.
                    if (logger.isDebugEnabled()) {
                        logger.debug("Failed to close session " + webSocketSession.getId(), closeFailure);
                    }
                }
            }
        }
    }

    /**
     * Pick the STOMP version for a CONNECTED frame from the accept-version of the
     * original CONNECT message stored in the CONNECT_ACK headers.
     *
     * @return "1.2" or "1.1" when accepted (in that order of preference), or
     * {@code null} when the client listed no versions
     * @throws IllegalArgumentException if the original CONNECT message is missing
     * @throws StompConversionException if none of the listed versions is supported
     */
    private String getVersion(StompHeaderAccessor stompHeaderAccessor) {
        Message<?> connectMessage = (Message<?>) stompHeaderAccessor.getHeader(CONNECT_MESSAGE_HEADER);
        // Validate before wrapping so a missing CONNECT yields this message rather than a failure inside wrap().
        Assert.notNull(connectMessage, "CONNECT_ACK does not contain original CONNECT " + stompHeaderAccessor);

        Set<String> acceptVersion = StompHeaderAccessor.wrap(connectMessage).getAcceptVersion();
        if (acceptVersion.contains(STOMP_VERSION_1_2)) {
            return STOMP_VERSION_1_2;
        }
        if (acceptVersion.contains(STOMP_VERSION_1_1)) {
            return STOMP_VERSION_1_1;
        }
        if (acceptVersion.isEmpty()) {
            return null;
        }
        throw new StompConversionException("Unsupported version '" + acceptVersion + "'");
    }

    /**
     * Finish preparing a CONNECTED frame: expose the principal's name as a native
     * header, register the session id in the user session registry (if one is set),
     * and disable SockJS heartbeats when the second heartbeat value of the frame is
     * greater than zero.
     */
    private void afterStompSessionConnected(StompHeaderAccessor stompHeaderAccessor, WebSocketSession webSocketSession) {
        Principal principal = webSocketSession.getPrincipal();
        if (principal != null) {
            stompHeaderAccessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
            if (this.userSessionRegistry != null) {
                this.userSessionRegistry.registerSessionId(resolveNameForUserSessionRegistry(principal), webSocketSession.getId());
            }
        }

        long[] heartbeat = stompHeaderAccessor.getHeartbeat();
        if (heartbeat[1] > 0) {
            WebSocketSession rawSession = WebSocketSessionDecorator.unwrap(webSocketSession);
            if (rawSession instanceof SockJsSession) {
                logger.debug("STOMP heartbeats negotiated, disabling SockJS heartbeats.");
                ((SockJsSession) rawSession).disableHeartbeat();
            }
        }
    }

    /**
     * Return the name under which the principal's sessions are registered: the
     * destination user name when the principal is a
     * {@link DestinationUserNameProvider}, otherwise {@link Principal#getName()}.
     */
    private String resolveNameForUserSessionRegistry(Principal principal) {
        String name = principal.getName();
        if (principal instanceof DestinationUserNameProvider) {
            name = ((DestinationUserNameProvider) principal).getDestinationUserName();
        }
        return name;
    }

    /**
     * Return the session id header of the given message.
     */
    @Override
    public String resolveSessionId(Message<?> message) {
        return StompHeaderAccessor.wrap(message).getSessionId();
    }

    /**
     * Raise the session's text message size limit to
     * {@link #MINIMUM_WEBSOCKET_MESSAGE_SIZE} if it is lower, and create the decoder
     * used for this session's inbound messages.
     */
    @Override
    public void afterSessionStarted(WebSocketSession webSocketSession, MessageChannel messageChannel) {
        if (webSocketSession.getTextMessageSizeLimit() < MINIMUM_WEBSOCKET_MESSAGE_SIZE) {
            webSocketSession.setTextMessageSizeLimit(MINIMUM_WEBSOCKET_MESSAGE_SIZE);
        }
        this.decoders.put(webSocketSession.getId(), new BufferingStompDecoder(getMessageSizeLimit()));
    }

    /**
     * Release per-session state, unregister the session id from the user session
     * registry (if one is set and the session has a principal), publish a
     * {@link SessionDisconnectEvent} when a publisher is set, and send a DISCONNECT
     * message for the session to the given channel.
     */
    @Override
    public void afterSessionEnded(WebSocketSession webSocketSession, CloseStatus closeStatus, MessageChannel messageChannel) {
        this.decoders.remove(webSocketSession.getId());

        Principal principal = webSocketSession.getPrincipal();
        if (this.userSessionRegistry != null && principal != null) {
            this.userSessionRegistry.unregisterSessionId(resolveNameForUserSessionRegistry(principal), webSocketSession.getId());
        }

        if (logger.isDebugEnabled()) {
            logger.debug("WebSocket session ended, sending DISCONNECT message to broker");
        }

        StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
        headers.setSessionId(webSocketSession.getId());
        Message<byte[]> disconnect = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();

        if (this.eventPublisher != null) {
            publishEvent(new SessionDisconnectEvent(this, webSocketSession.getId(), closeStatus));
        }
        messageChannel.send(disconnect);
    }
}
