package com.bcxin.tenant.open.rest.apis.components;

import com.alibaba.fastjson.JSONObject;
import com.bcxin.tenant.open.dubbo.common.configs.filters.AuthTokenFilter;
import com.bcxin.tenant.open.infrastructures.TenantContext;
import com.bcxin.tenant.open.infrastructures.TenantEmployeeContext;
import com.bcxin.tenant.open.infrastructures.components.JsonProvider;
import com.bcxin.tenant.open.infrastructures.exceptions.NotSupportTenantException;
import com.bcxin.tenant.open.infrastructures.exceptions.UnAuthorizedTenantException;
import com.bcxin.tenant.open.infrastructures.utils.ExceptionUtil;
import com.bcxin.tenant.open.jdks.PoliceIncidentsRpcProvider;
import com.bcxin.tenant.open.jdks.QueueRpcProvider;
import com.bcxin.tenant.open.jdks.requests.PoliceIncidentStatisticsRequest;
import com.bcxin.tenant.open.jdks.requests.SyncUpdateGeoRequest;
import com.bcxin.tenant.open.jdks.requests.enums.BroadcastMessageType;
import com.bcxin.tenant.open.jdks.responses.PoliceIncidentsLevelCountResponse;
import com.bcxin.tenant.open.rest.apis.dtos.RefreshPublishMessageResponse;
import com.bcxin.tenant.open.rest.apis.dtos.SocketSessionDTO;
import jakarta.websocket.*;
import jakarta.websocket.server.PathParam;
import jakarta.websocket.server.ServerEndpoint;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * WebSocket endpoint mounted at {@code /websocket/{socket_category}}.
 *
 * <p>Each opened session is registered in {@link #websocketClients} under its socket category,
 * which must be one of {@code connect}, {@code police_incidents} or {@code app_location}.
 * Shared state is kept in static fields. NOTE(review): presumably because the WebSocket container
 * creates its own endpoint instances, separate from the Spring bean that receives
 * {@link #setBeanFactory(BeanFactory)}; confirm against the endpoint exporter configuration.
 *
 * <p>Manual test tool: http://www.websocket-test.com/
 * <br>Reference: https://blog.csdn.net/w1014074794/article/details/113845879
 */
@Component
@ServerEndpoint("/websocket/{socket_category}")
public class WebSocketServer implements BeanFactoryAware {
    private static final Logger logger = LoggerFactory.getLogger(WebSocketServer.class);
    /** Session user-property key under which {@link #onOpen} stores the authenticated user. */
    private static final String SOCKET_CURRENT_USER = "SOCKET.CURRENT_USER";
    /** Category whose incoming messages are location updates forwarded to {@link QueueRpcProvider}. */
    private static final String APP_LOCATION = "app_location";
    /**
     * Socket categories accepted by {@link #onOpen}. Built on {@link Arrays#asList} rather than
     * {@code List.of} so that {@code contains(null)} returns {@code false} instead of throwing.
     */
    private static final Collection<String> ALLOWED_SOCKET_TOPICS =
            Collections.unmodifiableList(Arrays.asList("connect", "police_incidents", APP_LOCATION));

    /**
     * Open sessions keyed by socket category (channel name).
     *
     * <p>Each value is a {@link Collections#synchronizedCollection synchronized collection}; iterate
     * it only via {@link #snapshot(Collection)}. This class creates an entry on first use and never
     * removes one, so a value obtained from the map remains the live collection for that category.
     */
    public static Map<String, Collection<SocketSessionDTO>> websocketClients = new ConcurrentHashMap<String, Collection<SocketSessionDTO>>();

    /** Spring bean factory used to look up RPC providers; assigned once via {@link #setBeanFactory}. */
    private static volatile BeanFactory beanFactory;

    /**
     * Sends a text message to every open session registered under the given socket category.
     * A failure on one session is logged and does not prevent delivery to the remaining sessions.
     *
     * @param socketname socket category whose sessions receive the message; null/empty is a no-op
     * @param jsonString message text
     */
    public static void sendMessage(String socketname, String jsonString) {
        // ConcurrentHashMap#get throws on a null key.
        if (!StringUtils.hasLength(socketname)) {
            return;
        }
        Collection<SocketSessionDTO> socketSessions = websocketClients.get(socketname);
        if (socketSessions == null) {
            return;
        }
        for (SocketSessionDTO socketSession : snapshot(socketSessions)) {
            Session session = socketSession.getSession();
            if (!session.isOpen()) {
                continue;
            }
            try {
                sendTextSynchronously(session, jsonString);
            } catch (Exception ex) {
                logger.error("发送消息：{}，{} 异常-{}", socketname, jsonString, ExceptionUtil.getStackMessage(ex));
            }
        }
    }

    /**
     * Registers a newly opened session under its socket category.
     *
     * @param socketCategory the {@code {socket_category}} path segment
     * @param session        the opened session; the current user is stored in its user properties
     * @throws UnAuthorizedTenantException if no user is present in {@link TenantContext}
     * @throws NotSupportTenantException   if the category is not one of {@link #ALLOWED_SOCKET_TOPICS}
     */
    @OnOpen
    public void onOpen(@PathParam("socket_category") String socketCategory, Session session) {
        TenantEmployeeContext.TenantUserModel userModel = TenantContext.getInstance().getUserContext().get();
        if (userModel == null) {
            throw new UnAuthorizedTenantException("暂无授权; 无法访问");
        }

        if (!ALLOWED_SOCKET_TOPICS.contains(socketCategory)) {
            String allowedTopics = String.join(",", ALLOWED_SOCKET_TOPICS);
            logger.error("仅支持{}的socket连接", allowedTopics);
            throw new NotSupportTenantException(String.format("仅支持%s的socket连接", allowedTopics));
        }

        session.getUserProperties().put(SOCKET_CURRENT_USER, userModel);
        SocketSessionDTO socketSession = SocketSessionDTO.create(
                socketCategory,
                userModel.getEmployeeId(),
                userModel.getTencentUserId(),
                userModel.getAssignedSuperviseDepartIds(),
                session);

        // Atomic get-or-create: a separate get/put could let two concurrent opens each install
        // their own collection, silently dropping one of the sessions.
        websocketClients
                .computeIfAbsent(socketCategory, key -> Collections.synchronizedCollection(new ArrayList<>()))
                .add(socketSession);
    }

    /**
     * Logs an error reported by the container for a session.
     *
     * @param session the affected session
     * @param error   the reported error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        logger.info("服务端发生了错误{}", error.getMessage());
        // The stack trace used to be discarded; keep it at debug so info-level output is unchanged.
        logger.debug("websocket error on session {}", session == null ? null : session.getId(), error);
    }

    /**
     * Removes a closed session from the registry.
     *
     * <p>The path parameter name must match the {@code {socket_category}} template of
     * {@link ServerEndpoint}; binding a name that does not exist yields {@code null} and nothing
     * would ever be removed.
     *
     * @param socketname the {@code {socket_category}} path segment of the closed connection
     * @param session    the closed session
     */
    @OnClose
    public void onClose(@PathParam("socket_category") String socketname, Session session) {
        if (!StringUtils.hasLength(socketname)) {
            return;
        }
        Collection<SocketSessionDTO> socketSessions = websocketClients.get(socketname);
        if (socketSessions != null) {
            // removeIf on a synchronized collection holds its lock for the whole operation.
            socketSessions.removeIf(ix -> ix.getSession() == session);
        }
    }

    /**
     * Handles a text message from a client.
     *
     * <p>For {@code app_location} connections with a known user, the message is parsed as a
     * {@link SyncUpdateGeoRequest} and forwarded to {@link QueueRpcProvider#syncUpdatedLonLat}.
     * Every other message is acknowledged with an echo reply.
     *
     * @param socketCategory the {@code {socket_category}} path segment
     * @param message        message text
     * @param session        session the message arrived on
     */
    @OnMessage
    public void onMessage(@PathParam("socket_category") String socketCategory, String message, Session session) {
        if (!StringUtils.hasLength(socketCategory)) {
            return;
        }

        TenantEmployeeContext.TenantUserModel userModel =
                (TenantEmployeeContext.TenantUserModel) session.getUserProperties().get(SOCKET_CURRENT_USER);
        if (userModel == null) {
            userModel = TenantContext.getInstance().getUserContext().get();
        }

        if (userModel == null) {
            replyQuietly(session, "来自服务器：" + socketCategory + "你的消息(" + message + ")我收到啦");
            return;
        }

        if (APP_LOCATION.equalsIgnoreCase(socketCategory)) {
            forwardLocation(userModel, message);
            return;
        }

        String departIds = userModel.getAssignedSuperviseDepartIds() == null
                ? ""
                : String.join(",", userModel.getAssignedSuperviseDepartIds());
        replyQuietly(session, "来自服务器：" + socketCategory + "你(" + departIds + ")的消息(" + message + ")我收到啦");
    }

    /**
     * Pushes a JSON message to every registered session of every category, using the async remote.
     * Per-category delivery details (session descriptions and per-session failures) are logged.
     *
     * @param jsonObject message body
     * @throws IOException declared for backward compatibility; not thrown by this implementation
     */
    public void sendMessageAll(JSONObject jsonObject) throws IOException {
        String jsonData = jsonObject.toJSONString();
        Collection<Collection<SocketSessionDTO>> allClientValues = websocketClients.values();
        if (CollectionUtils.isEmpty(allClientValues)) {
            logger.error("sendMessageAll-未找到连接客户端:{}", jsonData);
            return;
        }

        for (Collection<SocketSessionDTO> socketSessions : allClientValues) {
            // Sequential on purpose: the builder is not thread-safe, and async sends do not block.
            StringBuilder sb = new StringBuilder();
            try {
                for (SocketSessionDTO ss : snapshot(socketSessions)) {
                    sb.append(ss.getDescription());
                    try {
                        ss.getSession().getAsyncRemote().sendText(jsonData);
                    } catch (Exception ex) {
                        sb.append(ex);
                    }
                }
            } catch (Exception ex) {
                sb.append(";").append(ex);
            } finally {
                logger.error("sendMessageAll-推送广播消息:{};数据={}", sb, jsonData);
            }
        }
    }

    /**
     * Notifies matching clients of the latest police-incident totals.
     *
     * <p>For each open session selected by {@link #getMatchSocketSessions}, the pending totals for
     * that session's supervise departments are fetched and pushed as JSON text.
     *
     * @param data broadcast descriptor (message type and target ids); {@code null} is a no-op
     */
    public void broadcastMessage(RefreshPublishMessageResponse data) {
        if (data == null) {
            return;
        }

        PoliceIncidentsRpcProvider policeIncidentsRpcProvider = beanFactory.getBean(PoliceIncidentsRpcProvider.class);
        Collection<SocketSessionDTO> socketSessions =
                getMatchSocketSessions(data.getMessageType(), data.getSuperviseDepartIds());
        // NOTE(review): each element performs a blocking RPC inside the common ForkJoinPool.
        socketSessions.parallelStream()
                .forEach(st -> {
                    try {
                        Session session = st.getSession();
                        if (!session.isOpen()) {
                            return;
                        }
                        PoliceIncidentsLevelCountResponse pr =
                                policeIncidentsRpcProvider.getPendingTotal(
                                        PoliceIncidentStatisticsRequest.create(true, st.getSuperviseDepartIds())
                                );

                        // @ServerEndpoint declares no encoders, so sendObject(pr) cannot encode a POJO
                        // and its failure only surfaced in the ignored async result. Send JSON text.
                        String payload = JSONObject.toJSONString(pr);
                        session.getAsyncRemote().sendText(payload, result -> {
                            if (!result.isOK()) {
                                logger.error("failed to send message at websocket", result.getException());
                            }
                        });
                    } catch (Exception ex) {
                        logger.error("failed to send message at websocket", ex);
                    }
                });
    }

    /**
     * Selects the registered sessions that a broadcast should reach.
     *
     * <ul>
     *   <li>{@code PoliceIncidents}: sessions whose {@code contain(specialIds)} is true.</li>
     *   <li>{@code TRTC}: sessions whose Tencent user id is in {@code specialIds}.</li>
     *   <li>Any other type: no sessions.</li>
     * </ul>
     *
     * @param messageType broadcast type; {@code null} selects nothing
     * @param specialIds  target ids; {@code null} selects nothing
     * @return matching sessions, never {@code null}
     */
    public Collection<SocketSessionDTO> getMatchSocketSessions(BroadcastMessageType messageType, Collection<String> specialIds) {
        if (messageType == null || specialIds == null) {
            return Collections.emptyList();
        }

        List<String> socketNames = websocketClients.keySet().stream()
                .filter(ALLOWED_SOCKET_TOPICS::contains)
                .collect(Collectors.toList());
        if (socketNames.isEmpty()) {
            return Collections.emptyList();
        }

        logger.error("当前websocketClients的列表为:{}; specialIds ={}",
                String.join(",", websocketClients.keySet()),
                String.join(",", specialIds)
        );

        Collection<SocketSessionDTO> candidates = new ArrayList<>();
        for (String socketName : socketNames) {
            Collection<SocketSessionDTO> sessions = websocketClients.get(socketName);
            if (sessions != null) {
                candidates.addAll(snapshot(sessions));
            }
        }

        if (candidates.isEmpty()) {
            return Collections.emptyList();
        }

        return switch (messageType) {
            case PoliceIncidents -> candidates.stream()
                    .filter(ix -> ix.contain(specialIds))
                    .collect(Collectors.toList());
            case TRTC -> candidates.stream()
                    .filter(ix -> specialIds.contains(ix.getTencentUserId()))
                    .collect(Collectors.toList());
            default -> new ArrayList<>();
        };
    }

    /**
     * Stores the Spring bean factory in the shared static field used by all endpoint instances.
     *
     * @param beanFactory the owning bean factory
     */
    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        WebSocketServer.beanFactory = beanFactory;
    }

    /**
     * Removes the given sessions from their category collections in {@link #websocketClients}.
     * Sessions with an empty or unsupported category are ignored.
     *
     * @param removedSessions sessions to remove; {@code null} or empty is a no-op
     */
    public synchronized void removeCloseSessions(Collection<SocketSessionDTO> removedSessions) {
        if (CollectionUtils.isEmpty(removedSessions)) {
            return;
        }
        for (SocketSessionDTO ix : removedSessions) {
            String category = ix.getCategory();
            if (!StringUtils.hasLength(category) || !ALLOWED_SOCKET_TOPICS.contains(category)) {
                continue;
            }
            Collection<SocketSessionDTO> socketSessions = websocketClients.get(category);
            if (!CollectionUtils.isEmpty(socketSessions)) {
                socketSessions.remove(ix);
                logger.warn("移除已关闭的session:{}", ix.getDescription());
            }
        }
    }

    /**
     * Copies a category collection while holding its lock. Iterating or streaming a synchronized
     * collection without that lock can throw {@link ConcurrentModificationException}.
     */
    private static List<SocketSessionDTO> snapshot(Collection<SocketSessionDTO> sessions) {
        synchronized (sessions) {
            return new ArrayList<>(sessions);
        }
    }

    /**
     * Sends text through the blocking remote while holding the session's monitor. The basic remote
     * may throw {@link IllegalStateException} if two sends on the same session overlap.
     */
    private static void sendTextSynchronously(Session session, String text) throws IOException {
        synchronized (session) {
            session.getBasicRemote().sendText(text);
        }
    }

    /** Sends a reply on the session, logging (not propagating) I/O failures. */
    private static void replyQuietly(Session session, String text) {
        try {
            sendTextSynchronously(session, text);
        } catch (IOException ex) {
            logger.error("发送消息异常：{}", ExceptionUtil.getStackMessage(ex));
        }
    }

    /**
     * Parses an {@code app_location} message and forwards it, tagged with the sender's employee id,
     * to {@link QueueRpcProvider#syncUpdatedLonLat}. Messages that deserialize to {@code null}
     * (e.g. the literal {@code null}) are skipped.
     */
    private void forwardLocation(TenantEmployeeContext.TenantUserModel userModel, String message) {
        QueueRpcProvider queueRpcProvider = beanFactory.getBean(QueueRpcProvider.class);
        JsonProvider jsonProvider = beanFactory.getBean(JsonProvider.class);
        SyncUpdateGeoRequest request = jsonProvider.toObject(SyncUpdateGeoRequest.class, message);
        if (request == null) {
            logger.warn("ignore unparseable app_location message: {}", message);
            return;
        }
        request.setEmployeeIds(Collections.singleton(userModel.getEmployeeId()));

        queueRpcProvider.syncUpdatedLonLat(request);
    }
}
