Spring中,依赖注入是最核心的功能,但依赖注入实现的基本原理非常简单,它通过Map存储Bean,通过反射给Bean注入依赖。
但Spring依赖注入的实现却非常复杂,在Spring中,每一个命名空间对应一个命名空间处理器,每一个标签对应一个标签解析器,并通过属性编辑器给Bean注入依赖等。
通过对Spring的研究,现简单地实现一个IoC容器,只包含属性注入的功能,也没有使用属性编辑器,主要代码如下,如果需要完整代码,可留下邮箱:
package cn.hang.ioc.api.impl;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import cn.hang.ioc.BeanDefinition;
import cn.hang.ioc.BeanDefinitionLoader;
import cn.hang.ioc.BeanInitializeException;
import cn.hang.ioc.BeanNotFoundException;
import cn.hang.ioc.BeanScope;
import cn.hang.ioc.DefaultBeanDefinitionLoader;
import cn.hang.ioc.DependencyInjectException;
import cn.hang.ioc.api.BeanFactory;
import cn.hang.ioc.io.ResourceLoadException;
import cn.hang.ioc.util.ReflectionUtils;
import cn.hang.ioc.xml.BeanDefinitionRegistry;
import cn.hang.ioc.xml.BeanProperty;
import cn.hang.ioc.xml.XmlParserException;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.reflect.AbstractInvocationHandler;
import com.google.common.reflect.Reflection;
public class ClassPathXmlApplicationContext implements BeanFactory, BeanDefinitionRegistry {
/**
* 存储beanDefinition对象
*/
private Map nameBeanDefinitions = new ConcurrentHashMap();
private Map, BeanDefinition> classBeanDefinitions = new ConcurrentHashMap, BeanDefinition>();
/**
* 加载配置文件并解析出BeanDefinition
*/
private BeanDefinitionLoader beanDefinitionLoader = getBeanDefinitionLoader();
/**
* 存储单例bean的Map
*/
private Map singletonMap = new HashMap();
public ClassPathXmlApplicationContext(String configFile) {
Preconditions.checkNotNull(beanDefinitionLoader);
Preconditions.checkArgument(configFile != null && !configFile.trim().equals(""));
loadResource(configFile);
initalizeSingletonBean();
}
/**
* 加载配置文件,解析出beanDefinition对象
*
* @param configFile
* 配置文件在classpath下的路径
* @throws ResourceLoadException
*/
private void loadResource(String configFile) throws ResourceLoadException {
InputStream in = null;
try {
in = Thread.currentThread().getContextClassLoader().getResourceAsStream(configFile);
beanDefinitionLoader.loadBeans(in);
} catch (XmlParserException e) {
throw new ResourceLoadException(e);
} finally {
if (in != null) {
try {
in.close();
} catch (IOException e) {
throw new ResourceLoadException(e);
}
}
}
}
/**
* 初始化单例的bean
*/
private void initalizeSingletonBean() {
// 初始化bean
for (Map.Entry en : nameBeanDefinitions.entrySet()) {
BeanDefinition beanDefinition = en.getValue();
if (beanDefinition.getScope() == BeanScope.SINGLETON) {
// 单例,初始化bean
Object bean = createBean(beanDefinition);
singletonMap.put(beanDefinition.getName(), bean);
}
}
// 注入依赖
for (Map.Entry en : nameBeanDefinitions.entrySet()) {
BeanDefinition beanDefinition = en.getValue();
if (beanDefinition.getScope() == BeanScope.SINGLETON) {
injectDependencies(singletonMap.get(beanDefinition.getName()), beanDefinition);
}
}
}
@Override
public Object getBean(String name) {
Preconditions.checkArgument(name != null && !name.trim().equals(""));
BeanDefinition beanDefinition = nameBeanDefinitions.get(name);
if (beanDefinition == null) {
throw new BeanNotFoundException(name);
}
Object object = singletonMap.get(name);
if (object == null) {
// prototype
object = createBean(beanDefinition);
}
return object;
}
@SuppressWarnings("unchecked")
@Override
public T getBean(Class clazz) {
Preconditions.checkNotNull(clazz);
BeanDefinition beanDefinition = classBeanDefinitions.get(clazz);
return (T) getBean(beanDefinition.getName());
}
@Override
public void regist(String name, BeanDefinition beanDefinition) {
Preconditions.checkNotNull(beanDefinition);
nameBeanDefinitions.put(name, beanDefinition);
classBeanDefinitions.put(beanDefinition.getBeanClass(), beanDefinition);
}
/**
* 创建bean对象
*
* @param beanDefinition
* 应该被创建的对象信息
* @return 需要的实际对象,不是代理
*/
private Object createBean(BeanDefinition beanDefinition) {
Class c = beanDefinition.getBeanClass();
try {
Object obj = c.newInstance();
return obj;
} catch (InstantiationException e) {
throw new BeanInitializeException(e);
} catch (IllegalAccessException e) {
throw new BeanInitializeException(e);
}
}
/**
* 属性注入bean的依赖对象
*
* @param bean
* 被注入的bean
* @param beanDefinition
*/
private void injectDependencies(Object bean, BeanDefinition beanDefinition) {
List beanProperties = beanDefinition.getDependencies();
for (BeanProperty beanProperty : beanProperties) {
String propName = beanProperty.getPropertyName();
if (Strings.isNullOrEmpty(propName)) {
throw new DependencyInjectException("Property name is must not be null!");
}
String ref = beanProperty.getReference();
if (Strings.isNullOrEmpty(ref)) {
// 注入属性的值
injectValue(bean, beanDefinition, beanProperty);
} else {
// 得到属性对应的bean
Object dependence = singletonMap.get(ref);
if (dependence == null) {
throw new DependencyInjectException(String.format("The reference %s is not exists or is not a Singleton bean!", ref));
}
invokeSetter(bean, propName, dependence);
}
}
}
/**
* 注入属性值
*
* @param bean
* 被注入的bean
* @param beanDefinition
* 被注入的bean的信息
* @param beanProperty
* 被注入的属性
*/
private void injectValue(Object bean, BeanDefinition beanDefinition, BeanProperty beanProperty) {
String propName = beanProperty.getPropertyName();
String value = beanProperty.getValue();
Preconditions.checkState(value != null);
Method setter = getSetterMethod(propName, beanDefinition.getBeanClass());
Class[] params = setter.getParameterTypes();
if (params.length != 1) {
throw new DependencyInjectException(String.format("The setter method of %s is not illegal!", propName));
}
Class param = params[0];
Object arg = convertStringToPrimitive(propName, value, param);
try {
setter.invoke(bean, arg);
} catch (Exception e) {
throw new DependencyInjectException(e);
}
}
/**
* 将字符串转换成原始数据类型
*
* @param propName
* 属性名
* @param value
* 字符串值
* @param param
* 需要转换成的类型
* @return 转换后的值
*/
private Object convertStringToPrimitive(String propName, String value, Class param) {
Object arg = null;
if (param == String.class) {
arg = value;
} else if (param == Integer.class || param == int.class) {
arg = Integer.parseInt(value);
} else if (param == Long.class || param == long.class) {
arg = Long.parseLong(value);
} else if (param == Short.class || param == short.class) {
arg = Short.parseShort(value);
} else if (param == Double.class || param == double.class) {
arg = Double.parseDouble(value);
} else if (param == Float.class || param == float.class) {
arg = Float.parseFloat(value);
} else if (param == Character.class || param == char.class) {
if (value.length() != 1) {
throw new DependencyInjectException("The length of a character is must 1!");
}
arg = value.charAt(0);
} else if (param == Byte.class || param == byte.class) {
arg = Byte.valueOf(value);
}
if (arg == null) {
throw new DependencyInjectException(String.format("The dependency of %s is must not be null", propName));
}
return arg;
}
/**
* 获取BeanDefinitionLoader对象,用于加载配置文件
*
* @return 默认的BeanDefinitionLoader
*/
protected BeanDefinitionLoader getBeanDefinitionLoader() {
return new DefaultBeanDefinitionLoader(this);
}
/**
* 调用属性的setter方法
*
* @param target
* 目标对象
* @param property
* 属性名
* @param arg
* 方法调用的参数
*/
private void invokeSetter(Object target, String property, Object arg) {
Preconditions.checkNotNull(target);
Preconditions.checkArgument(!Strings.isNullOrEmpty(property));
Class c = target.getClass();
Method setter = getSetterMethod(property, c);
if (setter == null) {
throw new DependencyInjectException(String.format("The property %s is not exists or the setter method is not visible!", property));
}
try {
setter.invoke(target, arg);
} catch (Exception e) {
throw new DependencyInjectException(e);
}
}
private Method getSetterMethod(String property, Class c) {
String methodName = "set" + Character.toUpperCase(property.charAt(0)) + property.substring(1);
Method setter = ReflectionUtils.lookupMethod(c, methodName);
return setter;
}
}
package cn.hang.ioc;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.Namespace;
import org.dom4j.io.SAXReader;
import cn.hang.ioc.util.ClassPath;
import cn.hang.ioc.util.PropertiesUtils;
import cn.hang.ioc.xml.BeanDefinitionRegistry;
import cn.hang.ioc.xml.NamespaceHandler;
import cn.hang.ioc.xml.NamespaceHandlerLoadException;
import cn.hang.ioc.xml.XmlParserException;
import com.google.common.base.Preconditions;
/**
 * Reads the XML configuration file. Namespace handlers are discovered through
 * handler files on the classpath; each handler file must map namespace URLs to
 * the class names of their handlers.
 *
 * @author hang.gao
 *
 */
public class DefaultBeanDefinitionLoader implements BeanDefinitionLoader {

    /**
     * Classpath location of the file mapping IoC namespaces to namespace handlers.
     */
    private static final String CLASSPATH_IOC_HANDLER_MAP = "META-INF/ioc.handlers";

    /**
     * The container into which parsed bean definitions are registered.
     */
    private final BeanDefinitionRegistry beanDefinitionRegistry;

    /**
     * Holds the namespace-to-handler map.
     */
    private final NamespaceHolder namespaceHolder = new NamespaceHolder();

    /**
     * @param beanDefinitionRegistry
     *            registry handed to every namespace handler; must not be null
     */
    public DefaultBeanDefinitionLoader(BeanDefinitionRegistry beanDefinitionRegistry) {
        Preconditions.checkNotNull(beanDefinitionRegistry);
        this.beanDefinitionRegistry = beanDefinitionRegistry;
    }

    /**
     * Parses the XML document from the stream and dispatches every child of the
     * root element to the handler registered for its namespace.
     *
     * @param in
     *            the XML input; must not be null. The stream is not closed here.
     * @throws XmlParserException
     *             if reading, handler loading or tag parsing fails
     */
    @Override
    public void loadBeans(InputStream in) throws XmlParserException {
        Preconditions.checkNotNull(in);
        SAXReader xmlReader = new SAXReader();
        try {
            // Root element of the configuration document.
            Document document = xmlReader.read(in);
            Element element = document.getRootElement();
            registNodeDeclaredNamespaceHandlers();
            // Process the child tags.
            parseChildren(element);
        } catch (XmlParserException e) {
            // Already the right type: rethrow as-is so its message is not buried
            // under a second XmlParserException wrapper.
            throw e;
        } catch (Exception e) {
            throw new XmlParserException(e);
        }
    }

    /**
     * Parses the children of the root element by delegating each one to the
     * handler of its namespace.
     *
     * @param element
     *            the root element
     * @throws XmlParserException
     *             if a child's namespace has no registered handler
     */
    @SuppressWarnings("unchecked")
    private void parseChildren(Element element) {
        List<Element> children = element.elements();
        // For each node, call its namespace handler, which in turn dispatches
        // to the tag handler for that tag.
        for (Element child : children) {
            NamespaceHandler handler = namespaceHolder.get(child.getNamespace());
            if (handler == null) {
                throw new XmlParserException("No namespace found:" + child.getNamespace());
            }
            handler.parseTag(child);
        }
    }

    /**
     * Loads every {@code META-INF/ioc.handlers} resource returned by the ClassPath
     * helper and registers the handlers it declares.
     *
     * @throws IOException
     *             if the resources cannot be enumerated
     */
    private void registNodeDeclaredNamespaceHandlers() throws IOException {
        for (URL url : ClassPath.current().getResources(CLASSPATH_IOC_HANDLER_MAP)) {
            namespaceHolder.putAll(loadHandlerMap(url));
        }
    }

    /**
     * Loads the namespace-to-handler mapping from the given file.
     *
     * @param url
     *            location of the handler file
     * @return the Properties read from the file
     */
    private Properties loadHandlerMap(URL url) {
        return PropertiesUtils.loadHandlerMap(url);
    }

    /**
     * Controls access to the map from namespace URI to namespace handler.
     *
     * @author hang.gao
     *
     */
    private class NamespaceHolder {

        /** Handlers keyed by namespace URI. */
        private final Map<String, NamespaceHandler> namespaceHandlerMap = new HashMap<String, NamespaceHandler>();

        /**
         * @param key
         *            the namespace of an element
         * @return the handler registered for the namespace URI, or null if none
         */
        public NamespaceHandler get(Namespace key) {
            return namespaceHandlerMap.get(key.getURI());
        }

        /**
         * Initialises the handler and stores it under the given namespace URI.
         *
         * @return the handler previously stored under the key, or null
         */
        private NamespaceHandler put(String key, NamespaceHandler value) {
            // Call the handler's init method before it becomes reachable through the map.
            value.init();
            return namespaceHandlerMap.put(key, value);
        }

        /**
         * Instantiates and registers one handler per entry; keys are namespace URIs
         * and values are handler class names.
         *
         * @param properties
         *            the mapping read from a handler file
         */
        public void putAll(Properties properties) {
            for (Map.Entry<Object, Object> entry : properties.entrySet()) {
                NamespaceHandler namespaceHandler = getNamespaceHandlerByClassName((String) entry.getValue());
                put(entry.getKey().toString(), namespaceHandler);
            }
        }

        /**
         * Creates a handler from its class name and hands it the bean registry.
         *
         * @param className
         *            name of a class that must implement NamespaceHandler
         * @return the new handler instance
         * @throws NamespaceHandlerLoadException
         *             if the class cannot be loaded, instantiated or cast
         */
        private NamespaceHandler getNamespaceHandlerByClassName(String className) {
            try {
                NamespaceHandler namespaceHandler = (NamespaceHandler) Class.forName(className).newInstance();
                namespaceHandler.setBeanDefinitionRegistry(beanDefinitionRegistry);
                return namespaceHandler;
            } catch (Exception e) {
                throw new NamespaceHandlerLoadException(e);
            }
        }
    }
}
package cn.hang.ioc.xml;
import java.util.Collections;
import java.util.List;
import org.dom4j.Attribute;
import org.dom4j.Element;
import cn.hang.ioc.BeanDefinition;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
/**
 * Tag handler for the {@code bean} tag: turns one bean element and its child
 * property elements into a {@link BeanDefinition} and registers it.
 *
 * @author hang.gao
 *
 */
public class BeanTagHandler implements TagHandler {

    /**
     * Name of the bean tag's name attribute.
     */
    private static final String NAME = "name";

    /**
     * Name of the bean tag's class attribute.
     */
    private static final String CLASS = "class";

    /**
     * Attribute of the property tag naming the bean property to set.
     */
    private static final String PROPERTY_NAME = "name";

    /**
     * Attribute of the property tag holding the name of the referenced bean.
     */
    private static final String PROPERTY_REF = "ref";

    /**
     * Attribute of the property tag holding a literal value.
     */
    private static final String PROPERTY_VALUE = "value";

    /**
     * Parses a bean element and registers the resulting definition under the
     * bean's name.
     *
     * @param element
     *            the bean element
     * @param registry
     *            the registry receiving the definition
     * @throws IllegalStateException
     *             if the name or class attribute is missing or blank
     * @throws XmlParserException
     *             if the bean class cannot be loaded or a property has no name
     */
    @Override
    public void parse(Element element, BeanDefinitionRegistry registry) {
        BeanDefinition beanDefinition = new BeanDefinition();
        String name = extractNotNoneAttributeValues(element, NAME);
        beanDefinition.setName(name);
        String clazz = extractNotNoneAttributeValues(element, CLASS);
        beanDefinition.setClazz(clazz);
        try {
            Class<?> beanClass = Class.forName(clazz);
            beanDefinition.setBeanClass(beanClass);
        } catch (ClassNotFoundException e) {
            throw new XmlParserException(e);
        }
        // Extract the configured dependencies.
        List<BeanProperty> properties = extractDependencies(element, beanDefinition);
        beanDefinition.setDependencies(properties);
        // Register the bean.
        registry.regist(name, beanDefinition);
    }

    /**
     * Extracts a required attribute value.
     *
     * @param element
     *            the current XML element
     * @param name
     *            the name of the attribute to read
     * @return the attribute value, never null or blank
     * @throws IllegalStateException
     *             if the attribute is missing or blank
     */
    private String extractNotNoneAttributeValues(Element element, String name) {
        String value = extractAttributeValue(element, name);
        Preconditions.checkState(value != null && !value.trim().equals(""),
                "The attribute %s must not be empty", name);
        return value;
    }

    /**
     * Parses the property dependencies configured for the current bean. Every
     * child element is treated as a property; each needs a name and either a
     * {@code ref} or a {@code value}.
     *
     * @param element
     *            the bean's XML element
     * @param beanDefinition
     *            the bean information parsed so far, including name and class
     * @return the configured properties; empty if the element has no children
     * @throws XmlParserException
     *             if a property has no name
     * @throws IllegalStateException
     *             if a property has neither a reference nor a non-blank value
     */
    private List<BeanProperty> extractDependencies(Element element, BeanDefinition beanDefinition) {
        @SuppressWarnings("unchecked")
        List<Element> propertyElements = element.elements();
        if (propertyElements == null || propertyElements.isEmpty()) {
            return Collections.<BeanProperty>emptyList();
        }
        List<BeanProperty> properties = Lists.newArrayList();
        for (Element propertyElement : propertyElements) {
            BeanProperty beanProperty = new BeanProperty(beanDefinition.getName());
            String propertyName = extractAttributeValue(propertyElement, PROPERTY_NAME);
            if (Strings.isNullOrEmpty(propertyName)) {
                throw new XmlParserException("The property name is must not empty!:" + beanDefinition.getName());
            }
            beanProperty.setPropertyName(propertyName);
            String propertyRef = extractAttributeValue(propertyElement, PROPERTY_REF);
            if (Strings.isNullOrEmpty(propertyRef)) {
                // No reference given, so a literal value is mandatory.
                String propertyValue = extractNotNoneAttributeValues(propertyElement, PROPERTY_VALUE);
                beanProperty.setValue(propertyValue);
            } else {
                beanProperty.setReference(propertyRef);
            }
            properties.add(beanProperty);
        }
        return properties;
    }

    /**
     * Extracts the value of an attribute from the element.
     *
     * @param element
     *            the current XML element
     * @param attrName
     *            the name of the attribute to read
     * @return the attribute value, or null when the element has no such attribute
     */
    private String extractAttributeValue(Element element, String attrName) {
        Attribute attribute = element.attribute(attrName);
        // An absent attribute is reported as null so callers can decide whether it
        // is optional (e.g. "ref") or required; failing here would reject a
        // property that legitimately carries only a "value".
        return attribute == null ? null : attribute.getValue();
    }
}