orm 中 lnsert 操作怎么获取生成的主键id值
mybatis 中的实现
public interface UserMapper extends BaseMapper<User> {
@Insert("INSERT INTO user (username, password) VALUES (#{user.username}, #{user.password})")
@Options(useGeneratedKeys = true, keyProperty = "user.id")
int insertUser(@Param("user") User user);
}
public class OptionsLanguageDriver extends XMLLanguageDriver implements LanguageDriver {
@Override
public SqlSource createSqlSource(Configuration configuration, XNode script, Class<?> parameterType) {
// 获取原始的SqlSource
SqlSource originalSqlSource = super.createSqlSource(configuration, script, parameterType);
// 判断SqlCommandType是否为INSERT
MappedStatement mappedStatement = configuration.getMappedStatement(script.getStringAttribute("id"));
if (mappedStatement.getSqlCommandType() == SqlCommandType.INSERT) {
// 创建OptionsSqlSource,并将原始的SqlSource作为委托
return new OptionsSqlSource(originalSqlSource);
}
return originalSqlSource;
}
private static class OptionsSqlSource implements SqlSource {
private final SqlSource delegate;
public OptionsSqlSource(SqlSource delegate) {
this.delegate = delegate;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
// 获取原始的BoundSql
BoundSql originalBoundSql = delegate.getBoundSql(parameterObject);
// 创建OptionsBoundSql,并将原始的BoundSql作为委托
return new OptionsBoundSql(originalBoundSql);
}
}
private static class OptionsBoundSql implements BoundSql {
private final BoundSql delegate;
public OptionsBoundSql(BoundSql delegate) {
this.delegate = delegate;
}
@Override
public String getSql() {
// 获取原始的SQL语句
String originalSql = delegate.getSql();
// 添加获取自动生成主键的语句
return originalSql + " SELECT LAST_INSERT_ID()";
}
// 其他方法的委托实现...
}
}
在 JdbcTemplate 基础上实现
protected Integer insert(T t, Boolean ignoreNull) {
String table = getTableName(t);
List<Field> filterField = getField(t, ignoreNull);
List<String> columnList = getColumns(filterField);
String columns = StrUtil.join(Const.SEPARATOR_COMMA, columnList);
// 构造占位符
String params = StrUtil.repeatAndJoin("?", columnList.size(), Const.SEPARATOR_COMMA);
// 构造值
Object[] values = filterField.stream().map(field -> ReflectUtil.getFieldValue(t, field)).toArray();
String sql = StrUtil.format("INSERT INTO {table} ({columns}) VALUES ({params})", Dict.create().set("table", table).set("columns", columns).set("params", params));
log.debug("【执行SQL】SQL:{}", sql);
log.debug("【执行SQL】参数:{}", JSONUtil.toJsonStr(values));
int update = jdbcTemplate.update(sql, values);
Long generatedId = jdbcTemplate.queryForObject("SELECT LAST_INSERT_ID()", Long.class);
Field pkField = getPkField(t);
if (pkField == null) {
throw new PrimaryKeyMissingException();
}
pkField.setAccessible(true);
pkField.set(t, generatedId);
return update;
}
封装 JdbcTemplate
BaseDao<T, P>
@Slf4j
public class BaseDao<T, P> {
private final JdbcTemplate jdbcTemplate;
private final Class<T> clazz;
@SuppressWarnings(value = "unchecked")
public BaseDao(JdbcTemplate jdbcTemplate) {
this.jdbcTemplate = jdbcTemplate;
clazz = (Class<T>) ((ParameterizedType) getClass().getGenericSuperclass()).getActualTypeArguments()[0];
}
/**
* 通用插入,自增列需要添加 {@link Pk} 注解
*
* @param t 对象
* @param ignoreNull 是否忽略 null 值
* @return 操作的行数
*/
@SneakyThrows
protected Integer insert(T t, Boolean ignoreNull) {
String table = getTableName(t);
List<Field> filterField = getField(t, ignoreNull);
List<String> columnList = getColumns(filterField);
String columns = StrUtil.join(Const.SEPARATOR_COMMA, columnList);
// 构造占位符
String params = StrUtil.repeatAndJoin("?", columnList.size(), Const.SEPARATOR_COMMA);
// 构造值
Object[] values = filterField.stream().map(field -> ReflectUtil.getFieldValue(t, field)).toArray();
String sql = StrUtil.format("INSERT INTO {table} ({columns}) VALUES ({params})", Dict.create().set("table", table).set("columns", columns).set("params", params));
log.debug("【执行SQL】SQL:{}", sql);
log.debug("【执行SQL】参数:{}", JSONUtil.toJsonStr(values));
int update = jdbcTemplate.update(sql, values);
Long generatedId = jdbcTemplate.queryForObject("SELECT LAST_INSERT_ID()", Long.class);
Field pkField = getPkField(t);
if (pkField == null) {
throw new PrimaryKeyMissingException();
}
pkField.setAccessible(true);
pkField.set(t, generatedId);
return update;
}
/**
* 通用根据主键删除
*
* @param pk 主键
* @return 影响行数
*/
protected Integer deleteById(P pk) {
String tableName = getTableName();
String sql = StrUtil.format("DELETE FROM {table} where id = ?", Dict.create().set("table", tableName));
log.debug("【执行SQL】SQL:{}", sql);
log.debug("【执行SQL】参数:{}", pk);
return jdbcTemplate.update(sql, pk);
}
/**
* 通用根据主键更新,自增列需要添加 {@link Pk} 注解
*
* @param t 对象
* @param pk 主键
* @param ignoreNull 是否忽略 null 值
* @return 操作的行数
*/
protected Integer updateById(T t, P pk, Boolean ignoreNull) {
String tableName = getTableName(t);
List<Field> filterField = getField(t, ignoreNull);
List<String> columnList = getColumns(filterField);
List<String> columns = columnList.stream().map(s -> StrUtil.appendIfMissing(s, " = ?")).collect(Collectors.toList());
String params = StrUtil.join(Const.SEPARATOR_COMMA, columns);
// 构造值
List<Object> valueList = filterField.stream().map(field -> ReflectUtil.getFieldValue(t, field)).collect(Collectors.toList());
valueList.add(pk);
Object[] values = ArrayUtil.toArray(valueList, Object.class);
String sql = StrUtil.format("UPDATE {table} SET {params} where id = ?", Dict.create().set("table", tableName).set("params", params));
log.debug("【执行SQL】SQL:{}", sql);
log.debug("【执行SQL】参数:{}", JSONUtil.toJsonStr(values));
return jdbcTemplate.update(sql, values);
}
/**
* 通用根据主键查询单条记录
*
* @param pk 主键
* @return 单条记录
*/
public T findOneById(P pk) {
String tableName = getTableName();
String sql = StrUtil.format("SELECT * FROM {table} where id = ?", Dict.create().set("table", tableName));
RowMapper<T> rowMapper = new BeanPropertyRowMapper<>(clazz);
log.debug("【执行SQL】SQL:{}", sql);
log.debug("【执行SQL】参数:{}", JSONUtil.toJsonStr(pk));
return jdbcTemplate.queryForObject(sql, new Object[]{pk}, rowMapper);
}
/**
* 根据对象查询
*
* @param t 查询条件
* @return 对象列表
*/
public List<T> findByExample(T t) {
String tableName = getTableName(t);
List<Field> filterField = getField(t, true);
List<String> columnList = getColumns(filterField);
List<String> columns = columnList.stream().map(s -> " and " + s + " = ? ").collect(Collectors.toList());
String where = StrUtil.join(" ", columns);
// 构造值
Object[] values = filterField.stream().map(field -> ReflectUtil.getFieldValue(t, field)).toArray();
String sql = StrUtil.format("SELECT * FROM {table} where 1=1 {where}", Dict.create().set("table", tableName).set("where", StrUtil.isBlank(where) ? "" : where));
log.debug("【执行SQL】SQL:{}", sql);
log.debug("【执行SQL】参数:{}", JSONUtil.toJsonStr(values));
List<Map<String, Object>> maps = jdbcTemplate.queryForList(sql, values);
List<T> ret = CollUtil.newArrayList();
maps.forEach(map -> ret.add(BeanUtil.fillBeanWithMap(map, ReflectUtil.newInstance(clazz), true, false)));
return ret;
}
/**
* 获取表名
*
* @param t 对象
* @return 表名
*/
private String getTableName(T t) {
Table tableAnnotation = t.getClass().getAnnotation(Table.class);
if (ObjectUtil.isNotNull(tableAnnotation)) {
return StrUtil.format("`{}`", tableAnnotation.name());
} else {
return StrUtil.format("`{}`", t.getClass().getName().toLowerCase());
}
}
/**
* 获取表名
*
* @return 表名
*/
private String getTableName() {
Table tableAnnotation = clazz.getAnnotation(Table.class);
if (ObjectUtil.isNotNull(tableAnnotation)) {
return StrUtil.format("`{}`", tableAnnotation.name());
} else {
return StrUtil.format("`{}`", clazz.getName().toLowerCase());
}
}
/**
* 获取列
*
* @param fieldList 字段列表
* @return 列信息列表
*/
private List<String> getColumns(List<Field> fieldList) {
// 构造列
List<String> columnList = CollUtil.newArrayList();
for (Field field : fieldList) {
Column columnAnnotation = field.getAnnotation(Column.class);
String columnName;
if (ObjectUtil.isNotNull(columnAnnotation)) {
columnName = columnAnnotation.name();
} else {
columnName = field.getName();
}
columnList.add(StrUtil.format("`{}`", columnName));
}
return columnList;
}
/**
* 获取字段列表 {@code 过滤数据库中不存在的字段,以及自增列}
*
* @param t 对象
* @param ignoreNull 是否忽略空值
* @return 字段列表
*/
private List<Field> getField(T t, Boolean ignoreNull) {
// 获取所有字段,包含父类中的字段
Field[] fields = ReflectUtil.getFields(t.getClass());
// 过滤数据库中不存在的字段,以及自增列
List<Field> filterField;
Stream<Field> fieldStream = CollUtil.toList(fields).stream().filter(field -> ObjectUtil.isNull(field.getAnnotation(Ignore.class)) || ObjectUtil.isNull(field.getAnnotation(Pk.class)));
// 是否过滤字段值为null的字段
if (ignoreNull) {
filterField = fieldStream.filter(field -> ObjectUtil.isNotNull(ReflectUtil.getFieldValue(t, field))).collect(Collectors.toList());
} else {
filterField = fieldStream.collect(Collectors.toList());
}
return filterField;
}
private Field getPkField(T t) {
// 获取所有字段,包含父类中的字段
Field[] fields = ReflectUtil.getFields(t.getClass());
Optional<Field> first = CollUtil.toList(fields).stream().filter(field -> ObjectUtil.isNotNull(field.getAnnotation(Pk.class))).findFirst();
boolean present =first.isPresent();
if (present) {
return first.get();
}
return null;
}
}
UserDao
@Repository
public class UserDao extends BaseDao<User, Long> {
@Autowired
public UserDao(JdbcTemplate jdbcTemplate) {
super(jdbcTemplate);
}
/**
* 保存用户
*
* @param user 用户对象
* @return 操作影响行数
*/
public Integer insert(User user) {
return super.insert(user, true);
}
/**
* 根据主键删除用户
*
* @param id 主键id
* @return 操作影响行数
*/
public Integer delete(Long id) {
return super.deleteById(id);
}
/**
* 更新用户
*
* @param user 用户对象
* @param id 主键id
* @return 操作影响行数
*/
public Integer update(User user, Long id) {
return super.updateById(user, id, true);
}
/**
* 根据主键获取用户
*
* @param id 主键id
* @return id对应的用户
*/
public User selectById(Long id) {
return super.findOneById(id);
}
/**
* 根据查询条件获取用户列表
*
* @param user 用户查询条件
* @return 用户列表
*/
public List<User> selectUserList(User user) {
return super.findByExample(user);
}
}
蓝天和白云是标配。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 周边上新:园子的第一款马克杯温暖上架
· 分享 3 个 .NET 开源的文件压缩处理库,助力快速实现文件压缩解压功能!
· Ollama——大语言模型本地部署的极速利器
· DeepSeek如何颠覆传统软件测试?测试工程师会被淘汰吗?
· 使用C#创建一个MCP客户端