本文已参与「新人创作礼」活动,一起开启掘金创作之路。
本文是对前面 MyBatis 使用及源码分析系列的总结。为了更清晰地理解 MyBatis 的原理,本文在简化源码的基础上,手写实现了一个简易版的 MyBatis。
1. 项目依赖以及架构
- pom 基于springboot 2.5.6
<!-- Web starter: embedded server + Spring MVC used by the @GetMapping test endpoints. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- MySQL JDBC driver; BaseExecutor loads the configured driver class via Class.forName. -->
<dependency>
    <groupId>mysql</groupId>
    <artifactId>mysql-connector-java</artifactId>
    <version>8.0.18</version>
</dependency>
<!-- Lombok: the code relies on @Data, @Getter and @Slf4j.
     Assumes the Spring Boot 2.5.6 parent/BOM manages the version - confirm in the full pom. -->
<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
    <optional>true</optional>
</dependency>
- 目录
2.自动装配类 MapperAutoConfig
- 只贴了一个扫描和向容器注册的方法
/**
 * Scans mappers, plugins and SQL statements, then registers one factory-bean definition
 * per mapper interface.
 * <p>Order: (1) collect mapper interfaces (TMapper is the marker type passed to the scanner),
 * (2) collect plugin classes (Intercepts marker), (3) load SQL statements from
 * application.yml, (4) register the mapper bean definitions.
 */
@Override
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
    // Root package searched for both mapper interfaces and plugin classes.
    String basePackage = "com.example.demo";
    // Step 1: mapper interfaces.
    doScanner(basePackage, TMapper.class, this.configuration.getMapperInterfaceList());
    // Plugin (interceptor) classes.
    doScanner(basePackage, Intercepts.class, this.configuration.getPluGinInterfaceList());
    // Step 2: SQL statements, kept in the application config file for convenience.
    doScannerStatement("application.yml");
    // Step 3: wrap every mapper interface in a MapperFactoryBean definition and register it.
    doRegister(registry);
}
/**
 * Registers a bean definition for every scanned mapper interface.
 * <p>The bean name is the interface's fully-qualified name, but the bean class is
 * MapperFactoryBean, so the mapper instance is produced by MapperFactoryBean#getObject.
 * The interface name and the shared Configuration are handed over as the
 * "mapperInterface" and "configuration" properties.
 */
private void doRegister(BeanDefinitionRegistry registry) {
    for (Object candidate : this.configuration.getMapperInterfaceList()) {
        Class<?> mapperType = (Class<?>) candidate;
        String beanName = mapperType.getName();
        // Build the definition straight against the factory-bean class rather than creating it
        // for the interface and swapping the bean class afterwards.
        AbstractBeanDefinition definition = BeanDefinitionBuilder
                .genericBeanDefinition(MapperFactoryBean.class)
                .addPropertyValue("mapperInterface", beanName)
                .addPropertyValue("configuration", this.configuration)
                .getBeanDefinition();
        registry.registerBeanDefinition(beanName, definition);
        log.info("add mapper for " + beanName);
    }
}
3. 获取 Mapper Bean 时,Spring 会调用 MapperFactoryBean 的 getObject 方法
/**
 * FactoryBean that supplies the runtime implementation of one mapper interface.
 * <p>When the container asks for the mapper bean it calls getObject(), which returns a JDK
 * dynamic proxy for the interface; every call on that proxy is handled by MapperProxy.
 */
@Data
public class MapperFactoryBean<T> implements FactoryBean<T> {
    private Class<T> mapperInterface;
    Configuration configuration;

    public MapperFactoryBean() {
    }

    /** Creates a proxy implementing only the mapper interface, backed by a MapperProxy. */
    @Override
    @SuppressWarnings("unchecked")
    public T getObject() throws Exception {
        ClassLoader loader = mapperInterface.getClassLoader();
        Class<?>[] proxiedTypes = new Class[]{mapperInterface};
        MapperProxy<T> handler = new MapperProxy<T>(mapperInterface, configuration);
        return (T) Proxy.newProxyInstance(loader, proxiedTypes, handler);
    }

    /** Reports the mapper interface as the type of object this factory produces. */
    @Override
    public Class<T> getObjectType() {
        return this.mapperInterface;
    }
}
4.代理调用类MapperProxy
/**
 * InvocationHandler behind every mapper proxy created by MapperFactoryBean.
 * <p>Each mapper-method call is turned into a MapperMethod (statement id, SQL and a
 * SqlSession proxy) and executed with the call's arguments.
 */
public class MapperProxy<T> implements InvocationHandler, Serializable {
    private static final long serialVersionUID = 1L;

    Class<T> mapperInterface;
    Configuration configuration;

    public MapperProxy(Class<T> mapperInterface, Configuration configuration) {
        this.mapperInterface = mapperInterface;
        this.configuration = configuration;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // toString/hashCode/equals (called by logging, debuggers, Spring internals) are not
        // SQL statements: answer them on this handler instead of looking up a nonexistent
        // "<interface>.toString" mapping.
        if (Object.class.equals(method.getDeclaringClass())) {
            return method.invoke(this, args);
        }
        // Resolve the statement id, SQL and return type for this call.
        MapperMethod mapperMethod = new MapperMethod(configuration, method, this.mapperInterface);
        // The actual execution starts here.
        return mapperMethod.execute(args);
    }
}
5. SqlSession 代理生成类 MapperMethod
/**
 * Captures what is needed to execute one mapper method.
 * <p>Captured data:
 * <ul>
 *   <li>the statement id ("interface FQN" + "." + method name);</li>
 *   <li>the declared return and parameter types;</li>
 *   <li>the SQL text registered for that id.</li>
 * </ul>
 * It also creates the SqlSession proxy used by execute().
 *
 * @throws IllegalStateException if no SQL is registered for the statement id
 */
public MapperMethod(Configuration configuration, Method method, Class<?> mapperInterface) {
    this.method = method;
    this.sqlCommend = new SqlCommend();
    String statementId = mapperInterface.getName() + "." + method.getName();
    sqlCommend.setStatementId(statementId);
    sqlCommend.setResultType(method.getReturnType());
    sqlCommend.setParamType(method.getParameterTypes());
    sqlCommend.setSql(configuration.getMappedStatements().get(statementId));
    // Fail fast with the offending id. Otherwise the null SQL only blows up inside the
    // executor, which logs the error and returns null/0 to the caller.
    if (sqlCommend.getSql() == null) {
        throw new IllegalStateException("no sql mapped for statement " + statementId);
    }
    // Create the SqlSession proxy.
    this.sqlSessionProxy = createSqlSessionProxy(configuration);
}
/**
 * Builds a JDK proxy implementing SqlSession.
 * <p>Every call on the proxy is handled by a SqlSessionInvocation bound to this method's
 * SqlCommend.
 */
private SqlSession createSqlSessionProxy(Configuration configuration) {
    ClassLoader loader = this.getClass().getClassLoader();
    Class<?>[] proxiedTypes = new Class[]{SqlSession.class};
    SqlSessionInvocation handler = new SqlSessionInvocation(configuration, sqlCommend);
    return (SqlSession) Proxy.newProxyInstance(loader, proxiedTypes, handler);
}
/**
 * Dispatches the mapper call to the SqlSession proxy according to the SQL type.
 * <p>Calls pass through SqlSessionInvocation first, which creates a new session per call.
 * <ul>
 *   <li>SELECT returning List: selectList.</li>
 *   <li>SELECT returning anything else: selectOne.</li>
 *   <li>SELECT declared void: the query is run and the rows are discarded.</li>
 *   <li>UPDATE / INSERT: update.</li>
 *   <li>DELETE: delete.</li>
 * </ul>
 *
 * @param args the mapper method arguments (null for a no-arg method, per JDK proxy rules)
 * @return the query/update result, or null for a void SELECT
 * @throws RuntimeException if the SQL type is missing or unsupported
 */
public Object execute(Object[] args) {
    String type = sqlCommend.getType();
    // switch on a null String would throw an opaque NullPointerException.
    if (type == null) {
        throw new RuntimeException("sql类型异常," + type);
    }
    String statementId = sqlCommend.getStatementId();
    switch (type) {
        case "SELECT":
            Class<?> returnType = this.method.getReturnType();
            if (List.class.equals(returnType)) {
                return this.sqlSessionProxy.selectList(statementId, args);
            }
            if (void.class.equals(returnType)) {
                // A void SELECT must not fall through into the UPDATE case (which would run the
                // SELECT via executeUpdate); run it as a query and discard the rows.
                this.sqlSessionProxy.selectList(statementId, args);
                return null;
            }
            return this.sqlSessionProxy.selectOne(statementId, args);
        case "UPDATE":
        case "INSERT":
            return this.sqlSessionProxy.update(statementId, args);
        case "DELETE":
            return this.sqlSessionProxy.delete(statementId, args);
        default:
            throw new RuntimeException("sql类型异常," + type);
    }
}
6. SqlSessionInvocation:新建 SqlSession 及处理事务的代理类
/**
 * Runs one SqlSession call on a freshly created DefualtSqlSession.
 * <p>Method#invoke wraps anything the session (or a plugin in front of its executor) throws in
 * an InvocationTargetException. That wrapper is removed so callers see the real failure:
 * unchecked exceptions and errors are rethrown as-is, checked ones are wrapped in a
 * RuntimeException. Commit/rollback is not implemented yet.
 */
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
        // A new session (and executor) per call.
        DefualtSqlSession sqlSession = getSqlSession();
        Object result = method.invoke(sqlSession, args);
        // TODO: commit the transaction here.
        return result;
    } catch (java.lang.reflect.InvocationTargetException e) {
        // TODO: roll back here.
        Throwable cause = e.getTargetException();
        if (cause instanceof RuntimeException || cause instanceof Error) {
            throw cause;
        }
        throw new RuntimeException(cause);
    } catch (Exception e) {
        // Reflection or session-setup failures (e.g. IllegalAccessException).
        throw new RuntimeException(e);
    }
}
/**
 * Builds a fresh session.
 * <p>A BaseExecutor for the current SqlCommend is wrapped by every configured plugin and
 * handed to a new DefualtSqlSession.
 */
private DefualtSqlSession getSqlSession() {
    Executor baseExecutor = new BaseExecutor(configuration, sqlCommend);
    Executor pluggedExecutor = (Executor) this.pluginAll(baseExecutor);
    return new DefualtSqlSession(pluggedExecutor);
}
/**
 * Wraps the executor with every scanned plugin class, in list order.
 * <p>Each plugin receives the result of the previous wrap, so the last plugin is outermost.
 * Plugins are instantiated from the scanned classes here; in a real Spring integration they
 * would be taken from the application context.
 * <p>Best-effort, as before: entries that are not Interceptor classes, or that cannot be
 * instantiated, are skipped.
 */
private Object pluginAll(Object executor) {
    List<Object> interfaceList = configuration.getPluGinInterfaceList();
    Object wrapped = executor;
    for (Object candidate : interfaceList) {
        // Filter up front instead of instantiating a class and then failing the cast.
        if (!(candidate instanceof Class) || !Interceptor.class.isAssignableFrom((Class<?>) candidate)) {
            continue;
        }
        Interceptor instance;
        try {
            // Class#newInstance is deprecated since Java 9 because it rethrows checked
            // constructor exceptions undeclared; use the no-arg constructor instead.
            instance = (Interceptor) ((Class<?>) candidate).getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            // Deliberately skipped: an unusable plugin class must not break every SQL call.
            continue;
        }
        // Same proxying as MyBatis: each plugin wraps what the previous one returned.
        wrapped = instance.plugin(wrapped);
    }
    return wrapped;
}
7. SqlSession 默认实现类 DefualtSqlSession
- 这里只贴一个selectList方法了
/**
 * Delegates straight to the executor.
 * <p>The executor may be a chain of plugin proxies. Those can rewrite the SQL or parameters
 * before execution and post-process the returned rows afterwards.
 */
@Override
public <E> List<E> selectList(String statement, Object parameter) {
    List<E> rows = this.executor.query(statement, parameter);
    return rows;
}
8.插件的执行代理类Plugin
- 插件实现类
@Slf4j
@Component
@Intercepts(method = "com.example.mybatis.executor.BaseExecutor.query")
public class PluginService implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
try {
Object target = invocation.getTarget();
log.info("Sql{}", ((BaseExecutor) target).getSqlCommend().getSql());
Object[] args = invocation.getArgs();
log.info("statementId:{}", args[0]);
log.info("参数:{}", args[1]);
} catch (Exception e) {
log.error("插件异常!");
}
return invocation.process();
}
}
- 插件的invoke代理类
/**
 * JDK dynamic-proxy handler that puts one Interceptor in front of a target object.
 * <p>A call is routed to the interceptor only when the interceptor's {@code @Intercepts}
 * method equals "<class of the original target>.<method name>". Every other call goes
 * straight to the wrapped target.
 */
public class Plugin implements InvocationHandler {
    Object target;
    Interceptor interceptor;

    public Plugin(Object target, Interceptor interceptor) {
        this.target = target;
        this.interceptor = interceptor;
    }

    /** Creates a proxy over all interfaces of target's class, handled by a new Plugin. */
    public static Object wrap(Object target, Interceptor interceptor) {
        Class<?> clazz = target.getClass();
        Class<?>[] interfaces = clazz.getInterfaces();
        return Proxy.newProxyInstance(clazz.getClassLoader(), interfaces, new Plugin(target, interceptor));
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // Simplified: the annotation names "<target class>.<method>" directly
        // (MyBatis uses @Signature and passes a MappedStatement in args).
        Intercepts annotation = interceptor.getClass().getAnnotation(Intercepts.class);
        // args are the two values from selectList(statementId, args).
        String invoked = originalTargetClass(target).getName() + "." + method.getName();
        if (annotation != null && annotation.method().equals(invoked)) {
            return interceptor.intercept(new Invocation(target, method, args));
        }
        try {
            return method.invoke(target, args);
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Surface the target's own exception instead of the reflection wrapper.
            throw e.getTargetException();
        }
    }

    /**
     * Returns the class of the innermost real object behind a chain of Plugin proxies.
     * <p>With stacked plugins the immediate target is itself a "$Proxy..." instance, so
     * comparing its class name would never match the annotation.
     */
    private static Class<?> originalTargetClass(Object candidate) {
        Object current = candidate;
        while (Proxy.isProxyClass(current.getClass())) {
            InvocationHandler handler = Proxy.getInvocationHandler(current);
            if (!(handler instanceof Plugin)) {
                break;
            }
            current = ((Plugin) handler).target;
        }
        return current.getClass();
    }
}
9.BaseExecutor最终的执行类
/**
 * Minimal JDBC executor that runs the single statement described by its SqlCommend.
 * <p>Every call opens its own connection (no pool). It binds the mapper arguments to a
 * PreparedStatement and closes the connection, statement and result set when done.
 * Failures are logged and reported as 0 (update/delete) or null (query).
 */
@Getter
@Slf4j
public class BaseExecutor implements Executor {
    Configuration configuration;
    SqlCommend sqlCommend;
    // Keys of the (simplified) connection settings in Configuration#getProperties.
    static String driverClassName = "driver-class-name";
    static String url = "url";
    static String username = "username";
    static String password = "password";

    public BaseExecutor(Configuration configuration, SqlCommend sqlCommend) {
        this.configuration = configuration;
        this.sqlCommend = sqlCommend;
    }

    /** Runs the statement as an INSERT/UPDATE; returns the update count, or 0 on failure. */
    @Override
    public int update(String statement, Object parameter) {
        try {
            List<Object> values = new ArrayList<>();
            String sql = this.getBoundSql(parameter, values);
            try (Connection conn = openConnection();
                 java.sql.PreparedStatement ps = prepareStatement(conn, sql, values)) {
                return ps.executeUpdate();
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            return 0;
        }
    }

    /** Runs the statement as a DELETE; returns the update count, or 0 on failure. */
    @Override
    public int delete(String statement, Object parameter) {
        try {
            List<Object> values = new ArrayList<>();
            String sql = this.getBoundSql(parameter, values);
            try (Connection conn = openConnection();
                 java.sql.PreparedStatement ps = prepareStatement(conn, sql, values)) {
                return ps.executeUpdate();
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            return 0;
        }
    }

    /** Runs the statement as a query; returns one column-label-to-value map per row, or null on failure. */
    @Override
    public <E> List<E> query(String statement, Object parameter) {
        try {
            List<Object> values = new ArrayList<>();
            String sql = this.getBoundSql(parameter, values);
            try (Connection conn = openConnection();
                 java.sql.PreparedStatement ps = prepareStatement(conn, sql, values);
                 ResultSet rs = ps.executeQuery()) {
                return handlerResultSet(rs);
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            return null;
        }
    }

    /**
     * Replaces every #{name} placeholder in the statement's SQL with "?" and appends the value
     * to bind for it, in order, to {@code values}.
     * <ul>
     *   <li>If the first mapper argument is a Map, each placeholder is looked up by name
     *       (a missing key binds SQL NULL).</li>
     *   <li>Otherwise placeholders are bound positionally to the mapper arguments.</li>
     * </ul>
     * <p>Values go through the driver instead of being pasted into the SQL text, so quotes in
     * the input cannot change the statement (SQL injection). MyBatis parses the same way
     * (GenericTokenParser#parse / SqlSourceBuilder#handleToken). It binds values through a
     * ParameterHandler, which plugins may proxy.
     *
     * @param parameter the mapper arguments as an Object[] (null for a no-arg mapper method)
     * @throws IllegalArgumentException if there are more positional placeholders than arguments
     */
    private String getBoundSql(Object parameter, List<Object> values) {
        String sql = this.sqlCommend.getSql();
        Object[] args = (Object[]) parameter;
        Map<?, ?> named = (args != null && args.length > 0 && args[0] instanceof Map)
                ? (Map<?, ?>) args[0] : null;
        StringBuilder bound = new StringBuilder(sql.length());
        int position = 0;
        int cursor = 0;
        while (true) {
            int start = sql.indexOf("#{", cursor);
            if (start < 0) {
                break;
            }
            int end = sql.indexOf('}', start + 2);
            if (end < 0) {
                // Unterminated placeholder: leave the remainder untouched.
                break;
            }
            String name = sql.substring(start + 2, end).trim();
            bound.append(sql, cursor, start).append('?');
            if (named != null) {
                values.add(named.get(name));
            } else {
                if (args == null || position >= args.length) {
                    throw new IllegalArgumentException("missing argument for #{" + name + "} in: " + sql);
                }
                values.add(args[position]);
            }
            position++;
            cursor = end + 1;
        }
        bound.append(sql, cursor, sql.length());
        String result = bound.toString();
        log.info("execute sql:{}, params:{}", result, values);
        return result;
    }

    /**
     * Prepares {@code sql} on {@code conn} and binds {@code values} to its parameters (1-based).
     * <p>If binding fails, the statement is closed before the exception propagates.
     */
    private java.sql.PreparedStatement prepareStatement(Connection conn, String sql, List<Object> values)
            throws SQLException {
        java.sql.PreparedStatement ps = conn.prepareStatement(sql);
        try {
            for (int i = 0; i < values.size(); i++) {
                ps.setObject(i + 1, values.get(i));
            }
            return ps;
        } catch (SQLException | RuntimeException e) {
            try {
                ps.close();
            } catch (SQLException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Converts every row into its own column-label-to-value map.
     * <p>A ResultHandler in MyBatis can be proxied by plugins at this point.
     */
    @SuppressWarnings("unchecked")
    private <E> List<E> handlerResultSet(ResultSet rs) throws SQLException {
        ResultSetMetaData metaData = rs.getMetaData();
        int columnCount = metaData.getColumnCount();
        List<Map<String, Object>> rows = new ArrayList<>();
        while (rs.next()) {
            // A new map per row: a single shared map made every list entry the same (last) row.
            Map<String, Object> row = new HashMap<>();
            for (int i = 1; i <= columnCount; i++) {
                // Label honours "AS alias" (MyBatis' default useColumnLabel behaviour).
                row.put(metaData.getColumnLabel(i), rs.getObject(i));
            }
            rows.add(row);
        }
        log.info("result:{}", rows);
        return (List<E>) rows;
    }

    /**
     * Opens a new JDBC connection from the configured driver, url, username and password.
     * <p>The caller owns the returned connection and must close it; normally a pool would
     * manage it.
     *
     * @throws RuntimeException if the driver cannot be loaded or the connection fails
     */
    private Connection openConnection() {
        try {
            Properties properties = configuration.getProperties();
            Class.forName(properties.getProperty(driverClassName));
            return DriverManager.getConnection(
                    properties.getProperty(url),
                    properties.getProperty(username),
                    properties.getProperty(password));
        } catch (Exception e) {
            throw new RuntimeException("数据库连接失败,e:" + e, e);
        }
    }
}
10.测试日志
10.1 insert测试
/**
 * Test endpoint that inserts a row with name = current epoch millis, age 22 and status 1.
 * <p>Returns the int reported by testMapper.insert (presumably the affected-row count).
 */
@GetMapping("/insert")
public Object insert() {
    Map<String, Object> row = new HashMap<>();
    row.put("name", System.currentTimeMillis());
    row.put("age", 22);
    row.put("status", 1);
    int affected = testMapper.insert(row);
    return affected;
}
- 日志输出
c.example.mybatis.executor.BaseExecutor : execute sql:insert into user_info(`name`,`age`,`status`) values('1647241298926','22','1')
10.2 update测试
/**
 * Test endpoint that sets the name of row {@code id} to "update_" + current epoch millis.
 * <p>Returns the int reported by testMapper.update.
 */
@GetMapping("/update")
public Object update(@RequestParam Long id) {
    Map<String, Object> changes = new HashMap<>();
    changes.put("id", id);
    String newName = "update_" + System.currentTimeMillis();
    changes.put("name", newName);
    int affected = testMapper.update(changes);
    return affected;
}
- 日志输出
BaseExecutor : execute sql:update user_info set name ='update_1647241489267' where id = '1'
10.3 select测试
/**
 * Test endpoint that fetches the row with the given id.
 * <p>The row is presumably a column-to-value map built by the executor.
 */
@GetMapping("/select")
public Object select(@RequestParam Long id) {
    Map<String, Object> row = testMapper.selectById(id);
    return row;
}
- 日志输出
com.example.demo.plugin.PluginService : Sqlselect * from user_info where id = #{id}
com.example.demo.plugin.PluginService : statementId:com.example.demo.mapper.TestMapper.selectById
com.example.demo.plugin.PluginService : 参数:[1]
c.example.mybatis.executor.BaseExecutor : execute sql:select * from user_info where id = 1
c.example.mybatis.executor.BaseExecutor : result:[{update_time=2021-09-12 18:17:16.0, create_time=2021-09-12 18:13:50.0, name=update_1647241489267, id=1, age=22, status=1}
本文大大简化了整个流程,便于读懂 MyBatis 的源码。源码地址直通车。以上就是本章的全部内容了。
上一篇:mybatis第八话 - mybaits之ParameterHandler参数处理源码分析 下一篇:mybatis第十话 - mybaits事务的源码分析
立身以立学为先,立学以读书为本