用Java手搓一个依赖注入框架

1、bean容器

/**
 * A minimal reflection-based dependency-injection container.
 *
 * <p>On construction it scans the package of the given main class and instantiates every class and
 * record annotated with {@code @Singleton}. It registers the objects returned by {@code @Singleton}
 * methods declared on those classes as beans, and fills {@code @Inject} fields. Dependencies that
 * are not available when a bean is first processed are deferred and retried after the other beans
 * exist, so the outcome does not depend on the order in which classes are scanned.
 *
 * <p>Uses unsynchronized collections internally.
 */
public class Container {

    private final static Logger log = Logger.getLogger(Container.class.getSimpleName());

    /** Bean name to bean instance; insertion-ordered so that type lookups are deterministic. */
    private final Map<String, Object> context = new java.util.LinkedHashMap<>();

    /** {@code @Inject} fields whose dependency was not registered when their owner was created. */
    private final List<SuspendBean> suspendBeans = new ArrayList<>();

    /** {@code @Singleton} factory methods whose parameters could not all be resolved yet. */
    private final List<SuspendBeanMethod> suspendBeanMethods = new ArrayList<>();

    /**
     * Builds the container by scanning the package that contains {@code mainClass}.
     *
     * @param mainClass class whose package is scanned for {@code @Singleton} types
     */
    public Container(Class<?> mainClass) {
        init(mainClass.getPackageName());
    }

    /**
     * Finds a bean by type.
     *
     * @param clazz the required type
     * @return the first bean (in registration order of the bean names) that is an instance of
     *     {@code clazz}, or empty if there is none
     */
    public <T> Optional<T> getBeanByType(Class<T> clazz) {
        return context.values().stream()
                .filter(clazz::isInstance)
                .map(clazz::cast)
                .findFirst();
    }

    /** Scans {@code packages}, creates all {@code @Singleton} beans and wires their dependencies. */
    private void init(String... packages) {
        List<Class<?>> classes = Packages.scanClasses(packages);
        log.info("扫描到:" + classes.stream().map(Class::getSimpleName).toList());
        List<Class<?>> pendingRecords = new ArrayList<>();
        for (Class<?> clazz : classes) {
            if (!clazz.isAnnotationPresent(Singleton.class)) {
                continue;
            }
            if (clazz.isRecord()) {
                // Records can only be built through their canonical constructor, so every
                // component must already be a bean; build them after the plain classes.
                pendingRecords.add(clazz);
            } else {
                createClassBean(clazz);
            }
        }
        resolveDeferredBeans(pendingRecords);
        // Fields last: they may depend on any record or factory-method bean.
        initSuspendField();
    }

    /** Instantiates a {@code @Singleton} class through its no-arg constructor and registers it. */
    private void createClassBean(Class<?> clazz) {
        try {
            var constructor = clazz.getDeclaredConstructor();
            constructor.setAccessible(true);
            Object object = constructor.newInstance();
            createMethodBean(object);
            injectField(object);
            addBean(getBeanName(clazz), object);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException("failed to create bean " + clazz.getName(), e);
        }
    }

    /**
     * Creates deferred record beans and suspended factory-method beans. Passes repeat while any of
     * them succeeds, so chains of deferred dependencies are resolved regardless of scan order.
     * When nothing more can be completed, the next record is built anyway, with {@code null} for
     * its missing components, and the loop continues. Factory methods that still cannot be called
     * at the end are logged.
     */
    private void resolveDeferredBeans(List<Class<?>> pendingRecords) {
        log.info("开始对未初始化完的bean的method进行依赖注入:");
        while (true) {
            // Non-short-circuit '|' so both kinds of deferred bean get a pass every round.
            boolean progress = createReadyRecords(pendingRecords) | initSuspendMethod();
            if (!progress) {
                if (pendingRecords.isEmpty()) {
                    break;
                }
                createRecordInstance(pendingRecords.remove(0), true);
            }
        }
        for (SuspendBeanMethod unresolved : suspendBeanMethods) {
            log.severe("无法创建实例:%s".formatted(unresolved.method().getName()));
        }
    }

    /** Builds every pending record whose components are all available; returns whether any was. */
    private boolean createReadyRecords(List<Class<?>> pendingRecords) {
        boolean created = false;
        for (Class<?> recordClass : List.copyOf(pendingRecords)) {
            if (createRecordInstance(recordClass, false)) {
                pendingRecords.remove(recordClass);
                created = true;
            }
        }
        return created;
    }

    /**
     * Instantiates the record {@code clazz} through its canonical constructor, passing the bean
     * found for each component's type, and registers it.
     *
     * @param allowMissing if {@code false}, nothing is created when a component has no matching
     *     bean; if {@code true}, {@code null} is passed for such components and a warning is logged
     * @return whether the record bean was created
     */
    private boolean createRecordInstance(Class<?> clazz, boolean allowMissing) {
        // getRecordComponents() follows the record header order, which is also the canonical
        // constructor's parameter order. getDeclaredFields() has no guaranteed order and also
        // contains static fields.
        var components = clazz.getRecordComponents();
        Class<?>[] componentTypes = new Class<?>[components.length];
        for (int i = 0; i < components.length; i++) {
            componentTypes[i] = components[i].getType();
        }
        Object[] args = lookupByTypes(componentTypes);
        for (int i = 0; i < args.length; i++) {
            if (args[i] == null) {
                if (!allowMissing) {
                    return false;
                }
                log.warning("no bean of type %s for component %d of %s"
                        .formatted(componentTypes[i].getName(), i, clazz.getName()));
            }
        }
        try {
            var constructor = clazz.getDeclaredConstructor(componentTypes);
            constructor.setAccessible(true);
            addBean(getBeanName(clazz), constructor.newInstance(args));
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException("failed to create record bean " + clazz.getName(), e);
        }
        return true;
    }

    /** Bean name of a class: its {@code @Named} value, else its decapitalized simple name. */
    private static <T> String getBeanName(Class<T> clazz) {
        return clazz.isAnnotationPresent(Named.class)
                ? clazz.getAnnotation(Named.class).value()
                : decapitalize(clazz.getSimpleName());
    }

    /** Bean name of a factory method: its {@code @Named} value, else the method name. */
    private static String getMethodBeanName(Method method) {
        return method.isAnnotationPresent(Named.class)
                ? method.getAnnotation(Named.class).value()
                : method.getName();
    }

    /** Name an {@code @Inject} field is looked up by: its {@code @Named} value, else the field name. */
    private static String getInjectName(Field field) {
        return field.isAnnotationPresent(Named.class)
                ? field.getAnnotation(Named.class).value()
                : field.getName();
    }

    /** Lower-cases the first character locale-independently ("AppleService" -> "appleService"). */
    private static String decapitalize(String name) {
        if (name.isEmpty()) {
            return name;
        }
        return name.substring(0, 1).toLowerCase(java.util.Locale.ROOT) + name.substring(1);
    }

    /** Injects the suspended {@code @Inject} fields now that all other beans have been created. */
    private void initSuspendField() {
        log.info("开始对未初始化完的bean的field进行依赖注入:");
        for (SuspendBean suspended : suspendBeans) {
            Field field = suspended.field();
            Object bean = resolveFieldBean(field);
            if (bean == null) {
                log.warning("no bean to inject into %s".formatted(field));
                continue;
            }
            setField(suspended.beanObject(), field, bean);
        }
    }

    /**
     * Retries every suspended factory method once.
     *
     * @return whether at least one of them produced a bean
     */
    private boolean initSuspendMethod() {
        boolean created = false;
        for (SuspendBeanMethod suspended : List.copyOf(suspendBeanMethods)) {
            if (invokeBeanMethod(suspended.beanObject(), suspended.method())) {
                suspendBeanMethods.remove(suspended);
                created = true;
            }
        }
        return created;
    }

    /**
     * Registers {@code object} under {@code name}, replacing any bean already registered there.
     *
     * @return this container, for chaining
     */
    public Container addBean(String name, Object object) {
        Object previous = context.put(name, object);
        if (previous != null && previous != object) {
            log.warning("bean name %s was already registered; replacing %s".formatted(name, previous));
        }
        log.info("添加bean name: %s, bean object: %s".formatted(name, object));
        return this;
    }

    /**
     * Injects the {@code @Inject} fields of {@code object} whose bean is already registered under
     * the field's inject name with a compatible type. Other fields are suspended for a later retry.
     */
    private void injectField(Object object) {
        for (Field field : object.getClass().getDeclaredFields()) {
            if (!field.isAnnotationPresent(Inject.class)) {
                continue;
            }
            Object bean = lookupFieldBeanByName(field);
            if (bean == null) {
                suspendBeans.add(new SuspendBean(object, field));
            } else {
                setField(object, field, bean);
            }
        }
    }

    /** The bean registered under the field's inject name, or null if absent or of the wrong type. */
    private Object lookupFieldBeanByName(Field field) {
        Object bean = context.get(getInjectName(field));
        return field.getType().isInstance(bean) ? bean : null;
    }

    /**
     * Final resolution for a suspended field. It tries the inject name first. For fields without
     * {@code @Named}, it then tries the decapitalized type name and finally any bean of the type.
     */
    private Object resolveFieldBean(Field field) {
        Object bean = lookupFieldBeanByName(field);
        if (bean != null || field.isAnnotationPresent(Named.class)) {
            // An explicit @Named qualifier is never replaced by a guess.
            return bean;
        }
        Object byTypeName = context.get(decapitalize(field.getType().getSimpleName()));
        if (field.getType().isInstance(byTypeName)) {
            return byTypeName;
        }
        return getBeanByType(field.getType()).orElse(null);
    }

    /** Sets {@code field} on {@code owner}, bypassing access checks (injected fields are private). */
    private static void setField(Object owner, Field field, Object value) {
        try {
            field.setAccessible(true);
            field.set(owner, value);
        } catch (IllegalAccessException e) {
            throw new RuntimeException("failed to inject " + field, e);
        }
    }

    /**
     * Registers the beans returned by {@code object}'s {@code @Singleton} factory methods. Methods
     * whose parameters cannot all be resolved yet are suspended for a later retry.
     */
    private void createMethodBean(Object object) {
        for (Method method : object.getClass().getDeclaredMethods()) {
            if (isBeanFactoryMethod(method) && !invokeBeanMethod(object, method)) {
                suspendBeanMethods.add(new SuspendBeanMethod(object, method));
            }
        }
    }

    /** Whether {@code method} is annotated {@code @Singleton} and returns a value. */
    private static boolean isBeanFactoryMethod(Method method) {
        Class<?> returnType = method.getReturnType();
        // Compare by identity: returnType.isAssignableFrom(Void.class) is also true for an Object
        // return type, which would wrongly skip factories declared to return Object.
        return method.isAnnotationPresent(Singleton.class)
                && returnType != void.class
                && returnType != Void.class;
    }

    /**
     * Calls the factory {@code method} on {@code owner} with beans resolved by parameter type and
     * registers the result.
     *
     * @return {@code false}, without calling the method, if some parameter has no matching bean
     */
    private boolean invokeBeanMethod(Object owner, Method method) {
        Object[] args = lookupByTypes(method.getParameterTypes());
        if (!allPresent(args)) {
            return false;
        }
        try {
            // Factory methods are found via getDeclaredMethods() and need not be public.
            method.setAccessible(true);
            Object methodBean = method.invoke(owner, args);
            addBean(getMethodBeanName(method), methodBean);
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new RuntimeException("failed to create bean from method " + method, e);
        }
        return true;
    }

    /** One bean per requested type (see {@link #getBeanByType}); null where none is registered. */
    private Object[] lookupByTypes(Class<?>[] types) {
        Object[] beans = new Object[types.length];
        for (int i = 0; i < types.length; i++) {
            beans[i] = getBeanByType(types[i]).orElse(null);
        }
        return beans;
    }

    private static boolean allPresent(Object[] beans) {
        for (Object bean : beans) {
            if (bean == null) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the bean registered under {@code name}.
     *
     * @throws RuntimeException if there is no such bean or it is not an instance of {@code clazz}
     */
    public <T> T getBean(String name, Class<T> clazz) {
        Object bean = context.get(name);
        if (clazz.isInstance(bean)) {
            return clazz.cast(bean);
        }
        throw new RuntimeException("no bean named '%s' with %s type".formatted(name, clazz.getSimpleName()));
    }

    /** An {@code @Inject} field of {@code beanObject} still waiting for its dependency. */
    record SuspendBean(Object beanObject, Field field) {
    }

    /** A {@code @Singleton} factory method of {@code beanObject} still waiting for its parameters. */
    record SuspendBeanMethod(Object beanObject, Method method) {
    }

}

2、支持的注入方式

字段注入

/**
 * Demo bean registered under the explicit name {@code myBananaService}. The container fills its
 * {@link AppleService} dependency through {@code @Inject}.
 */
@Singleton
@Named("myBananaService")
public class BananaService {

    /** Filled in by the container. */
    @Inject
    private AppleService appleService;

    /** Calls the apple service first, then prints this bean's own marker line. */
    public void banana() {
        AppleService apple = this.appleService;
        apple.apple();
        printMarker();
    }

    private static void printMarker() {
        System.out.println("banana");
    }

}

方法参数注入

/**
 * Configuration bean. Its {@code @Singleton} methods act as factories, and the container
 * registers each returned object as a bean.
 */
@Singleton
public class AppConfig {

    /** Factory for a fresh {@link HelloService}. */
    @Singleton
    public HelloService helloService() {
        HelloService hello = new HelloService();
        return hello;
    }

    /** Prints a start marker, then builds a {@link WorldService} on top of the given hello service. */
    @Singleton
    public WorldService worldService(HelloService helloService) {
        System.out.println("开始初始化: worldService");
        WorldService world = new WorldService(helloService);
        return world;
    }

}

record注入

/** Record bean; the container supplies its {@code helloService} component through the constructor. */
@Singleton
public record RecordService(HelloService helloService) {

    /** Prints a marker line, then delegates to the hello service. */
    public void record() {
        final String marker = "record service";
        System.out.println(marker);
        helloService().hello();
    }

}

3、使用

/** Demo entry point: builds the container from this class's package and exercises each bean. */
public class DemoApp {

    public static void main(String[] args) {
        var container = new Container(DemoApp.class);

        // Fetch every bean up front; getBean throws if a name/type pair is missing.
        var apple = container.getBean("appleService", AppleService.class);
        var banana = container.getBean("myBananaService", BananaService.class);
        var cherry = container.getBean("cherryService", CherryService.class);
        var world = container.getBean("worldService", WorldService.class);
        var recordBean = container.getBean("recordService", RecordService.class);

        // Then call them, in the same order.
        apple.apple();
        banana.banana();
        cherry.cherry();
        world.world();
        recordBean.record();
    }

}
posted @ 漠孤烟  阅读(12)  评论(0)  编辑  收藏  举报
编辑推荐:
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· 写一个简单的SQL生成工具
· AI 智能体引爆开源社区「GitHub 热点速览」
· C#/.NET/.NET Core技术前沿周刊 | 第 29 期(2025年3.1-3.9)
点击右上角即可分享
微信分享提示