// 扫描指定包下所有类 (scan all classes under the specified packages)

import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.io.support.ResourcePatternUtils;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.core.type.classreading.SimpleMetadataReaderFactory;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.RestController;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;


/**
 * Scans base packages on the classpath and collects the classes that are directly annotated
 * with {@link RestController}.
 *
 * <p>NOTE(review): {@code Logger} and {@code LoggerFactory} are used below but are not imported
 * in this file; presumably they are {@code org.slf4j.Logger} / {@code org.slf4j.LoggerFactory}
 * — confirm and add the imports.
 */
public class App {
    private static final Logger logger = LoggerFactory.getLogger(App.class);

    /** Ant-style pattern, relative to a package directory, matching every class file beneath it. */
    private static final String RESOURCE_PATTERN = "**/*.class";

    /**
     * Prefix that makes the pattern resolver search every classpath root (all directories and
     * jars) instead of only the first root containing the package. Same value as Spring's
     * {@code ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX}.
     */
    private static final String CLASSPATH_ALL_URL_PREFIX = "classpath*:";

    public static void main(String[] args) throws Exception {

        System.out.println(scanPackages("com.g2.xx.trade.svr.rest".split(",")));
    }

    /**
     * Scans each base package and collects every {@link RestController}-annotated class found.
     *
     * <p>Scanning is best effort: a package whose resources cannot be read is logged (with the
     * cause) and skipped, and the remaining packages are still scanned.
     *
     * @param basePackages dot-separated package names; {@code null}, blank entries and
     *                     surrounding whitespace are tolerated
     * @return the matching classes from all packages, in scan order; never {@code null}
     */
    static List<Class<?>> scanPackages(String[] basePackages) {
        List<Class<?>> candidates = new ArrayList<Class<?>>();
        if (basePackages == null) {
            return candidates;
        }
        for (String rawPackage : basePackages) {
            // Skip empty tokens such as those produced by splitting "a,,b"; an empty package would
            // otherwise become the pattern "/**/*.class", i.e. a scan rooted at the classpath root.
            if (!StringUtils.hasText(rawPackage)) {
                continue;
            }
            String pkg = rawPackage.trim();
            try {
                candidates.addAll(findCandidateClasses(pkg));
            } catch (IOException e) {
                // Pass the exception as the last argument so the cause and stack trace are logged.
                logger.error("扫描指定注解@RestController的基础包{}时出现异常", pkg, e);
            }
        }
        return candidates;
    }


    /**
     * Finds the classes under {@code basePackage} (including sub-packages) that are directly
     * annotated with {@link RestController}.
     *
     * @param basePackage dot-separated package name, e.g. {@code com.example.rest}
     * @return the matching classes; empty if none were found
     * @throws IOException if the classpath resources cannot be resolved or read
     */
    private static List<Class<?>> findCandidateClasses(String basePackage) throws IOException {
        if (logger.isDebugEnabled()) {
            logger.debug("开始扫描指定包{}下的所有类", basePackage);
        }
        List<Class<?>> candidates = new ArrayList<Class<?>>();
        String packageSearchPath = CLASSPATH_ALL_URL_PREFIX
                + replaceDotByDelimiter(basePackage) + '/' + RESOURCE_PATTERN;
        ResourceLoader resourceLoader = new DefaultResourceLoader();
        MetadataReaderFactory readerFactory = new SimpleMetadataReaderFactory(resourceLoader);
        Resource[] resources = ResourcePatternUtils.getResourcePatternResolver(resourceLoader)
                .getResources(packageSearchPath);
        String annotationName = RestController.class.getName();
        for (Resource resource : resources) {
            if (!resource.isReadable()) {
                continue;
            }
            MetadataReader reader = readerFactory.getMetadataReader(resource);
            // Test the annotation on the bytecode metadata first, so only matching classes are
            // loaded. Per its Spring definition @RestController is not @Inherited, so a
            // direct-presence check matches what Class.getAnnotation(RestController.class) reports.
            if (!reader.getAnnotationMetadata().hasAnnotation(annotationName)) {
                continue;
            }
            Class<?> candidateClass = transform(reader.getClassMetadata().getClassName());
            if (candidateClass == null) {
                continue;
            }
            candidates.add(candidateClass);
            logger.debug("扫描到@RestController注解基础类:{}", candidateClass.getName());
        }
        return candidates;
    }

    /**
     * Converts a dot-separated package name into a "/"-separated resource path.
     *
     * @param path dot-separated package name
     * @return the same name with every "." replaced by "/"
     */
    private static String replaceDotByDelimiter(String path) {
        return StringUtils.replace(path, ".", "/");
    }

    /**
     * Loads a class by its fully-qualified name using the class loader that loaded {@code App}.
     *
     * @param className fully-qualified class name
     * @return the loaded class, or {@code null} if it could not be found or linked (the failure
     *         is logged with its cause)
     */
    private static Class<?> transform(String className) {
        try {
            return ClassUtils.forName(className, App.class.getClassLoader());
        } catch (ClassNotFoundException e) {
            logger.error("未找到指定类:{}", className, e);
        } catch (LinkageError e) {
            // e.g. NoClassDefFoundError when one of the class's dependencies is missing: skip this
            // class rather than letting the error abort the whole scan.
            logger.error("未找到指定类:{}", className, e);
        }
        return null;
    }

}

 

// posted @ 2022-05-21 22:38  zslm___  阅读(114)  评论(0)  编辑  收藏  举报