package com.bcxin.web.commons.components;

import com.bcxin.saas.core.components.ThreadContextManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.awt.image.ImageProducer;
import java.sql.Connection;
import java.sql.SQLException;
import java.time.Instant;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
 * Per-execution-context key/value store.
 *
 * <p>Registered as a Spring component with the default (singleton) scope, so a single
 * instance serves every caller. The backing map is looked up on every call:
 * <ul>
 *   <li>when {@link #isWebRequest()} is true, the map is kept as an attribute of the
 *       current servlet request under {@link #CURRENT_THREAD_KEY};</li>
 *   <li>otherwise (or if accessing the request fails) a thread-local map is used.</li>
 * </ul>
 *
 * <p>NOTE(review): the thread-local is an {@link InheritableThreadLocal} with the default
 * {@code childValue}, so a child thread created after its parent initialised the map
 * shares the very same {@link HashMap} instance with the parent. {@code HashMap} is not
 * thread-safe — confirm that this sharing is intended.
 */
@Component
public class ThreadContextManagerImpl implements ThreadContextManager {
    private static final Logger logger = LoggerFactory.getLogger(ThreadContextManagerImpl.class);

    /** Request attribute name under which the per-request map is stored. */
    private static final String CURRENT_THREAD_KEY = "CURRENT_THREAD_KEY.RQ";

    /**
     * Fallback storage used outside of a web request.
     *
     * <p>{@code initialValue()} guarantees that every thread sees a non-null map. Seeding
     * the value from a static initializer only covers the thread that loads this class
     * and leaves {@code get()} returning {@code null} on any thread that did not inherit
     * a value.
     */
    private static final ThreadLocal<Map<String, Object>> threadManager =
            new InheritableThreadLocal<Map<String, Object>>() {
                @Override
                protected Map<String, Object> initialValue() {
                    return new HashMap<>();
                }
            };

    // Presumably a request-scoped proxy injected by the container, so this singleton can
    // reach whichever request is bound to the calling thread — TODO confirm.
    @Resource
    private HttpServletRequest servletRequest;

    /**
     * Stores {@code instance} under {@code key} in the current context map.
     *
     * @param key      entry name
     * @param instance value to store; {@code null} removes the entry instead
     * @param <T>      value type
     */
    @Override
    public <T> void store(String key, T instance) {
        Map<String, Object> context = getMap();
        if (instance == null) {
            // Map.remove is a no-op for a missing key, no containsKey probe required.
            context.remove(key);
        } else {
            context.put(key, instance);
        }

        // Re-bind the map so it is attached to the request / thread it was resolved for.
        setMap(context);
    }

    /**
     * Returns the value stored under {@code key}.
     *
     * @param key entry name
     * @param <T> type the caller expects; the cast is unchecked, so a mismatching stored
     *            value surfaces as a {@link ClassCastException} at the call site
     * @return the stored value, or {@code null} when there is none
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> T get(String key) {
        return (T) getMap().get(key);
    }

    /**
     * Returns the value stored under {@code key}, creating and storing it through
     * {@code supplier} when it is missing or when it is a {@link Connection} that has
     * already been closed.
     *
     * <p>Any exception raised while checking or creating the value is logged and not
     * propagated; in that case the previously stored value (possibly {@code null}) is
     * returned.
     *
     * @param key      entry name
     * @param supplier factory invoked when a new value is needed
     * @param tClass   expected value type (kept for interface compatibility; the
     *                 connection check inspects the stored instance itself)
     * @param <T>      value type
     * @return the existing or newly created value; may be {@code null} on failure
     */
    @Override
    public <T> T get(String key, Supplier<T> supplier, Class<T> tClass) {
        T instance = get(key);
        try {
            boolean createNew = instance == null;

            // Check the runtime type of the stored object. The previous
            // tClass.isAssignableFrom(Connection.class) test was inverted: it is false
            // for concrete Connection implementation classes and true for Object.class.
            if (!createNew && instance instanceof Connection) {
                createNew = ((Connection) instance).isClosed();
            }

            if (createNew) {
                instance = supplier.get();
                store(key, instance);
            }
        } catch (Exception ex) {
            logger.error("获取实例信息（{}）发生异常", key, ex);
        }

        return instance;
    }

    /**
     * Closes every {@link AutoCloseable} value held in the current context map and then
     * empties the map. A failure to close one resource is logged and does not prevent
     * the remaining resources from being closed.
     *
     * @throws SQLException declared by the interface; this implementation does not
     *                      throw it itself
     */
    @Override
    public void clear() throws SQLException {
        Map<String, Object> context = getMap();

        // Release every releasable resource. Closing does not modify the map, so the
        // values can be iterated directly.
        for (Object value : context.values()) {
            if (!(value instanceof AutoCloseable)) {
                continue;
            }
            try {
                ((AutoCloseable) value).close();
            } catch (Exception ex) {
                logger.error("资源释放发生异常：{}", value.getClass(), ex);
            }
        }

        context.clear();
        setMap(context);
    }

    /**
     * Removes the entry stored under {@code key}, if any.
     *
     * @param key entry name
     */
    @Override
    public void remove(String key) {
        store(key, null);
    }

    /**
     * Tells whether the caller is currently running inside a servlet request.
     *
     * @return {@code true} when a request is injected and reports a non-empty request
     *         URI; {@code false} otherwise, including when accessing the request throws
     */
    @Override
    public boolean isWebRequest() {
        try {
            return servletRequest != null && StringUtils.hasLength(servletRequest.getRequestURI());
        } catch (Exception ignored) {
            // Deliberately best-effort: presumably the injected request proxy throws when
            // no request is bound to this thread (e.g. background jobs) — TODO confirm.
            // That situation simply means "not a web request".
            return false;
        }
    }

    /**
     * Resolves the map for the current context: the request attribute when inside a web
     * request (created lazily), otherwise the thread-local map.
     *
     * @return the context map, never {@code null}
     */
    private Map<String, Object> getMap() {
        if (isWebRequest()) {
            try {
                @SuppressWarnings("unchecked")
                Map<String, Object> requestMap =
                        (Map<String, Object>) servletRequest.getAttribute(CURRENT_THREAD_KEY);
                if (requestMap == null) {
                    requestMap = new HashMap<>();
                    setMap(requestMap);
                }

                return requestMap;
            } catch (Exception ex) {
                logger.error("Failed to get map in servletRequest", ex);
            }
        }

        return threadManager.get();
    }

    /**
     * Binds {@code mp} to the current context: as a request attribute when inside a web
     * request, otherwise (or when that fails) to the thread-local.
     *
     * @param mp map to bind
     */
    private void setMap(Map<String, Object> mp) {
        if (isWebRequest()) {
            try {
                servletRequest.setAttribute(CURRENT_THREAD_KEY, mp);
                return;
            } catch (Exception ex) {
                logger.error("Failed to set map in servletRequest", ex);
            }
        }

        threadManager.set(mp);
    }
}
