// 扫描指定包下所有类 (scan all classes under the given packages and keep those annotated with @RestController)
import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; import org.springframework.core.io.support.ResourcePatternUtils; import org.springframework.core.type.classreading.MetadataReader; import org.springframework.core.type.classreading.MetadataReaderFactory; import org.springframework.core.type.classreading.SimpleMetadataReaderFactory; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; import org.springframework.web.bind.annotation.RestController; import java.io.IOException; import java.util.ArrayList; import java.util.List; public class App { private static final Logger logger = LoggerFactory.getLogger(App.class); private static final String RESOURCE_PATTERN = "**/*.class"; public static void main(String[] args) throws Exception { System.out.println(scanPackages("com.g2.xx.trade.svr.rest".split(","))); } static List<Class<?>> scanPackages(String[] basePackages) { List<Class<?>> candidates = new ArrayList<Class<?>>(); for (String pkg : basePackages) { try { candidates.addAll(findCandidateClasses(pkg)); } catch (IOException e) { logger.error("扫描指定注解@RestController的基础包{}时出现异常", pkg); continue; } } return candidates; } /** * 获取符合要求的Controller名称 * * @param basePackage * @return * @throws IOException */ private static List<Class<?>> findCandidateClasses(String basePackage) throws IOException { if (logger.isDebugEnabled()) { logger.debug("开始扫描指定包{}下的所有类" + basePackage); } List<Class<?>> candidates = new ArrayList<Class<?>>(); String packageSearchPath = replaceDotByDelimiter(basePackage) + '/' + RESOURCE_PATTERN; ResourceLoader resourceLoader = new DefaultResourceLoader(); MetadataReaderFactory readerFactory = new SimpleMetadataReaderFactory(resourceLoader); Resource[] resources = ResourcePatternUtils.getResourcePatternResolver(resourceLoader).getResources(packageSearchPath); for (Resource resource : resources) { MetadataReader reader = 
readerFactory.getMetadataReader(resource); Class<?> candidateClass = transform(reader.getClassMetadata().getClassName()); if (candidateClass == null) { continue; } RestController alias = candidateClass.getAnnotation(RestController.class); if (alias == null) { continue; } candidates.add(candidateClass); logger.debug("扫描到@RestController注解基础类:{}" + candidateClass.getName()); } return candidates; } /** * 用"/"替换包路径中"." * * @param path * @return */ private static String replaceDotByDelimiter(String path) { return StringUtils.replace(path, ".", "/"); } /** * @param className * @return */ private static Class<?> transform(String className) { Class<?> clazz = null; try { clazz = ClassUtils.forName(className, App.class.getClassLoader()); } catch (ClassNotFoundException e) { logger.error("未找到指定类:{}", className); } return clazz; } }