package cn.gtmap.landtax.support.jpa;

import org.apache.commons.collections.CollectionUtils;

import javax.persistence.Entity;
import javax.persistence.ManyToOne;
import javax.persistence.OneToOne;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.JarURLConnection;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLConnection;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Records, for every JPA {@code @Entity} of a business module, which associated entities are marked
 * with {@link CustomFetch}.
 * <p>
 * Created by zhouzhiwei on 2015-11-14.
 * <p>
 * The information is collected once, in the constructor, by scanning (non-recursively) the package
 * {@code "cn.gtmap." + businessName + ".entity"}. Calling {@link #setBusinessName(String)} afterwards
 * does not rebuild it.
 */
public class BaseRepositoryCustomFetch {

    private static final Logger LOGGER = Logger.getLogger(BaseRepositoryCustomFetch.class.getName());

    /** Suffix of compiled class files, both on disk and inside jars. */
    private static final String CLASS_SUFFIX = ".class";

    /**
     * Business name used to build the entity package path {@code "cn.gtmap." + businessName + ".entity"}.
     * If the entity classes live in a differently shaped package, change how {@code packageName} is
     * built in the constructor.
     */
    private String businessName;

    /**
     * Key: simple name of an {@code @Entity} class in the scanned package.
     * Value: simple names of the field types (without duplicates) of that entity's
     * {@code @ManyToOne}/{@code @OneToOne} fields that are also annotated with {@link CustomFetch}.
     * <p>
     * Held per instance so that instances built for different business names do not overwrite each
     * other's data. Only written during construction.
     */
    private final Map<String, List<String>> entityCustomFetchMap = new HashMap<String, List<String>>();

    /**
     * Scans the entity package of {@code businessName} and fills the CustomFetch map.
     * Classes that cannot be loaded are logged and skipped.
     *
     * @param businessName business name used to build the entity package path
     */
    public BaseRepositoryCustomFetch(String businessName) {
        this.businessName = businessName;

        // Only the package itself is scanned, not its sub-packages.
        String packageName = "cn.gtmap." + businessName + ".entity";
        for (String className : getClassName(packageName, false)) {
            Class<?> clazz;
            try {
                // Load only, never instantiate: annotations can be read from the Class object, and
                // instantiation would fail for classes without an accessible no-arg constructor.
                clazz = Class.forName(className);
            } catch (ClassNotFoundException e) {
                LOGGER.log(Level.WARNING, "Cannot load class " + className + ", skipped", e);
                continue;
            }
            collectCustomFetchFields(clazz, packageName);
        }
    }

    /**
     * Records the CustomFetch associations of {@code clazz} if it is an {@code @Entity}.
     * A field qualifies when its declared type is a (non-generic, non-interface) class of the scanned
     * package and it carries {@link CustomFetch} together with {@code @ManyToOne} or {@code @OneToOne}.
     *
     * @param clazz       candidate class
     * @param packageName scanned entity package
     */
    private void collectCustomFetchFields(Class<?> clazz, String packageName) {
        if (!clazz.isAnnotationPresent(Entity.class)) {
            return;
        }
        // The trailing "." stops e.g. "...entity" from also matching a sibling package "...entityx".
        String entityPrefix = packageName + ".";
        for (Field field : getEntityFieldList(clazz)) {
            Class<?> fieldType = field.getType();
            boolean isEntityReference = field.getGenericType() instanceof Class
                    && !fieldType.isInterface()
                    && fieldType.getName().startsWith(entityPrefix);
            if (!isEntityReference || !field.isAnnotationPresent(CustomFetch.class)) {
                continue;
            }
            // CustomFetch is only honoured on single-valued (ManyToOne / OneToOne) associations.
            if (field.isAnnotationPresent(ManyToOne.class) || field.isAnnotationPresent(OneToOne.class)) {
                setEntityCustomFetchMap(clazz.getSimpleName(), fieldType.getSimpleName());
            }
        }
    }

    /**
     * Returns all fields declared by {@code clazz} and by every one of its superclasses,
     * the class's own fields first.
     *
     * @param clazz class to inspect
     * @return declared fields of the class hierarchy (never null)
     */
    private static List<Field> getEntityFieldList(Class<?> clazz) {
        List<Field> fieldList = new ArrayList<Field>();
        for (Class<?> current = clazz; current != null; current = current.getSuperclass()) {
            fieldList.addAll(Arrays.asList(current.getDeclaredFields()));
        }
        return fieldList;
    }

    /**
     * Adds {@code fetchEntityName} to the fetch list of {@code entityName}, creating the list on first
     * use and ignoring duplicates.
     *
     * @param entityName      simple name of the owning entity
     * @param fetchEntityName simple name of the associated entity to fetch
     */
    private void setEntityCustomFetchMap(String entityName, String fetchEntityName) {
        List<String> fetchEntities = entityCustomFetchMap.get(entityName);
        if (fetchEntities == null) {
            fetchEntities = new ArrayList<String>();
            entityCustomFetchMap.put(entityName, fetchEntities);
        }
        if (!fetchEntities.contains(fetchEntityName)) {
            fetchEntities.add(fetchEntityName);
        }
    }

    /**
     * Returns the simple names of the associated entities marked with {@link CustomFetch} in the given
     * entity.
     *
     * @param entityName simple name of an entity class
     * @return the fetch list, or {@code null} if the entity has none (the internal list is returned, not a copy)
     */
    public List<String> getSubEntityFetchMap(String entityName) {
        return entityCustomFetchMap.get(entityName);
    }

    /**
     * Returns the fully-qualified names of the classes in a package. Inner classes (names containing
     * {@code '$'}) are excluded.
     * <p>
     * Packages are looked up through the thread context class loader (falling back to this class's
     * loader). Directory and jar locations are supported. If the package resource is not found and
     * the loader is a {@link URLClassLoader}, all of its jars are searched.
     *
     * @param packageName package name, e.g. {@code "cn.gtmap.x.entity"}
     * @param isRecursion whether sub-packages are included
     * @return fully-qualified class names; empty (never null) when nothing is found
     */
    public static Set<String> getClassName(String packageName, boolean isRecursion) {
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        if (loader == null) {
            loader = BaseRepositoryCustomFetch.class.getClassLoader();
        }
        String packagePath = packageName.replace('.', '/');

        URL url = loader.getResource(packagePath);
        if (url != null) {
            String protocol = url.getProtocol();
            if ("file".equals(protocol)) {
                return getClassNameFromDir(toFile(url).getPath(), packageName, isRecursion);
            }
            if ("jar".equals(protocol)) {
                return getClassNameFromJarUrl(url, packageName, isRecursion);
            }
            LOGGER.log(Level.WARNING, "Unsupported protocol for package " + packageName + ": " + url);
            return new HashSet<String>();
        }

        // Package resource not found directly: search every jar of the class path.
        // Java 9+ application class loaders are not URLClassLoaders, so check before casting.
        if (loader instanceof URLClassLoader) {
            return getClassNameFromJars(((URLClassLoader) loader).getURLs(), packageName, isRecursion);
        }
        return new HashSet<String>();
    }

    /**
     * Converts a {@code file:} URL to a {@link File}. Percent-escapes such as {@code %20} are decoded;
     * {@link URL#getPath()} leaves them encoded, which breaks paths containing spaces or non-ASCII
     * characters.
     */
    private static File toFile(URL url) {
        try {
            return new File(url.toURI());
        } catch (URISyntaxException e) {
            return new File(url.getPath());
        } catch (IllegalArgumentException e) {
            // e.g. URIs with an authority component (UNC paths)
            return new File(url.getPath());
        }
    }

    /**
     * Returns the classes of a package stored in a class-path directory.
     *
     * @param filePath    directory that corresponds to {@code packageName}
     * @param packageName package name matching {@code filePath}
     * @param isRecursion whether sub-directories (sub-packages) are included
     * @return fully-qualified class names (empty when the directory cannot be listed)
     */
    private static Set<String> getClassNameFromDir(String filePath, String packageName, boolean isRecursion) {
        Set<String> classNames = new HashSet<String>();
        File[] children = new File(filePath).listFiles();
        if (children == null) {
            // listFiles() returns null when the path does not exist or is not a directory.
            return classNames;
        }
        for (File child : children) {
            String fileName = child.getName();
            if (child.isDirectory()) {
                if (isRecursion) {
                    classNames.addAll(getClassNameFromDir(child.getPath(), packageName + "." + fileName, true));
                }
            } else if (fileName.endsWith(CLASS_SUFFIX) && fileName.indexOf('$') < 0) {
                String simpleName = fileName.substring(0, fileName.length() - CLASS_SUFFIX.length());
                classNames.add(packageName + "." + simpleName);
            }
        }
        return classNames;
    }

    /**
     * Opens the jar behind a {@code jar:} URL and returns the classes of the package it contains.
     *
     * @return fully-qualified class names (empty when the jar cannot be opened)
     */
    private static Set<String> getClassNameFromJarUrl(URL url, String packageName, boolean isRecursion) {
        JarFile jarFile = null;
        try {
            URLConnection connection = url.openConnection();
            if (!(connection instanceof JarURLConnection)) {
                return new HashSet<String>();
            }
            JarURLConnection jarConnection = (JarURLConnection) connection;
            // Request an uncached JarFile so that closing it below cannot affect other users of the
            // JVM-wide jar URL cache.
            jarConnection.setUseCaches(false);
            jarFile = jarConnection.getJarFile();
            return getClassNameFromJar(jarFile.entries(), packageName, isRecursion);
        } catch (IOException e) {
            LOGGER.log(Level.WARNING, "Cannot open jar " + url, e);
            return new HashSet<String>();
        } finally {
            closeQuietly(jarFile);
        }
    }

    /**
     * Returns the classes of a package found among the given jar entries.
     *
     * @param jarEntries  entries of a jar file
     * @param packageName package name
     * @param isRecursion whether sub-packages are included
     * @return fully-qualified class names
     */
    private static Set<String> getClassNameFromJar(Enumeration<JarEntry> jarEntries, String packageName,
                                                   boolean isRecursion) {
        Set<String> classNames = new HashSet<String>();
        // Match on the raw "/"-separated entry name; the trailing "/" keeps a package from matching a
        // sibling package that merely shares its prefix.
        String packagePath = packageName.replace('.', '/') + "/";

        while (jarEntries.hasMoreElements()) {
            JarEntry jarEntry = jarEntries.nextElement();
            String entryName = jarEntry.getName();
            if (jarEntry.isDirectory()
                    || !entryName.startsWith(packagePath)
                    || !entryName.endsWith(CLASS_SUFFIX)
                    || entryName.indexOf('$') >= 0) {
                continue;
            }
            // Strip only the trailing ".class", not every occurrence of the text.
            String classPath = entryName.substring(0, entryName.length() - CLASS_SUFFIX.length());
            // Without recursion, only entries directly inside the package (no further "/") qualify.
            if (!isRecursion && classPath.indexOf('/', packagePath.length()) >= 0) {
                continue;
            }
            classNames.add(classPath.replace('/', '.'));
        }
        return classNames;
    }

    /**
     * Searches every jar of the given class path for the package and returns its classes.
     * Directories (e.g. {@code classes/}) are skipped.
     *
     * @param urls        class path URLs
     * @param packageName package name
     * @param isRecursion whether sub-packages are included
     * @return fully-qualified class names
     */
    private static Set<String> getClassNameFromJars(URL[] urls, String packageName, boolean isRecursion) {
        Set<String> classNames = new HashSet<String>();

        for (URL url : urls) {
            // Compiled-class directories are not jars; no need to search them.
            if (url.getPath().endsWith("classes/")) {
                continue;
            }
            File file = toFile(url);
            if (file.isDirectory()) {
                continue;
            }

            JarFile jarFile = null;
            try {
                jarFile = new JarFile(file);
                classNames.addAll(getClassNameFromJar(jarFile.entries(), packageName, isRecursion));
            } catch (IOException e) {
                LOGGER.log(Level.WARNING, "Cannot open class path entry " + url, e);
            } finally {
                closeQuietly(jarFile);
            }
        }
        return classNames;
    }

    /** Closes {@code jarFile} if non-null, logging (not propagating) any failure. */
    private static void closeQuietly(JarFile jarFile) {
        if (jarFile == null) {
            return;
        }
        try {
            jarFile.close();
        } catch (IOException e) {
            LOGGER.log(Level.FINE, "Failed to close jar " + jarFile.getName(), e);
        }
    }

    /** @return the business name this instance was built for */
    public String getBusinessName() {
        return businessName;
    }

    /**
     * Updates the business name. The CustomFetch map built by the constructor is not rebuilt.
     *
     * @param businessName new business name
     */
    public void setBusinessName(String businessName) {
        this.businessName = businessName;
    }
}
