Sharding-JDBC Read/Write Splitting

In the normal flow, a SQL statement is executed roughly as follows:
the client requests SQL execution, a session is obtained from the sessionFactory, a database connection is opened from the session, a statement is created, the SQL is executed, and the result is returned.

To add read/write splitting, the choice between the master and a slave has to be made at the point where the database connection is opened, as in the following flow:
the client requests SQL execution, a session is obtained from the sessionFactory, a sharding-jdbc proxy connection is opened from the session, the proxy connection routes by SQL type (DQL without a hint goes to a slave, everything else goes to the master), the real connection is obtained, a statement is created, the SQL is executed, and the result is returned.

1. Obtain the Sharding-JDBC data source
//dataSourceMap holds the mapping to the real data sources: the key is the database (data source) name and the value is the DataSource; shardingRuleConfig holds the rules used for routing
ShardingDataSourceFactory.createDataSource(dataSourceMap, shardingRuleConfig, new HashMap<>(0), null);

public static DataSource createDataSource(final Map<String, DataSource> dataSourceMap, final ShardingRuleConfiguration shardingRuleConfig, 
                                              final Map<String, Object> configMap, final Properties props) throws SQLException {
    return new ShardingDataSource(shardingRuleConfig.build(dataSourceMap), configMap, props);
}
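
For context, here is a minimal sketch of how the two inputs could be wired up for read/write splitting. The MasterSlaveRuleConfiguration setters and getMasterSlaveRuleConfigs() are assumptions based on the sharding-jdbc 2.x configuration POJOs and may differ between versions; ds_master, ds_slave_0 and ds_slave_1 are hypothetical data source names, and the pooled DataSource objects and imports are omitted as in the other snippets.

// Physical data sources (hypothetical names); masterDataSource/slaveDataSource0/slaveDataSource1
// would be ordinary pooled DataSources created elsewhere.
Map<String, DataSource> dataSourceMap = new HashMap<>();
dataSourceMap.put("ds_master", masterDataSource);
dataSourceMap.put("ds_slave_0", slaveDataSource0);
dataSourceMap.put("ds_slave_1", slaveDataSource1);

// Assumed 2.x-style POJO: one logical data source "ds_ms" backed by one master and two slaves.
MasterSlaveRuleConfiguration masterSlaveRuleConfig = new MasterSlaveRuleConfiguration();
masterSlaveRuleConfig.setName("ds_ms");
masterSlaveRuleConfig.setMasterDataSourceName("ds_master");
masterSlaveRuleConfig.setSlaveDataSourceNames(Arrays.asList("ds_slave_0", "ds_slave_1"));

ShardingRuleConfiguration shardingRuleConfig = new ShardingRuleConfiguration();
shardingRuleConfig.getMasterSlaveRuleConfigs().add(masterSlaveRuleConfig);

// The returned DataSource is used like any plain JDBC DataSource; all routing happens inside it.
DataSource dataSource = ShardingDataSourceFactory.createDataSource(dataSourceMap, shardingRuleConfig, new HashMap<>(0), null);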
2. Obtain a database connection from the wrapped ShardingDataSource
public class ShardingDataSource extends AbstractDataSourceAdapter implements AutoCloseable {
    ...
    @Override
    public ShardingConnection getConnection() throws SQLException {
        return new ShardingConnection(shardingContext);
    }
    ...
}

What is returned is again a proxy connection wrapped by sharding-jdbc; it is passed a shardingContext parameter that holds the global context.

public class ShardingDataSource extends AbstractDataSourceAdapter implements AutoCloseable {
    public ShardingDataSource(final ShardingRule shardingRule, final Map<String, Object> configMap, final Properties props) throws SQLException {
        super(shardingRule.getDataSourceMap().values());
        if (!configMap.isEmpty()) {
            ConfigMapContext.getInstance().getShardingConfig().putAll(configMap);
        }
        shardingProperties = new ShardingProperties(null == props ? new Properties() : props);
        int executorSize = shardingProperties.getValue(ShardingPropertiesConstant.EXECUTOR_SIZE);
        executorEngine = new ExecutorEngine(executorSize);
        boolean showSQL = shardingProperties.getValue(ShardingPropertiesConstant.SQL_SHOW);
        //context: holds the routing rules, the database type, the executor engine, and whether to print SQL
        shardingContext = new ShardingContext(shardingRule, getDatabaseType(), executorEngine, showSQL);
    }
    ...
}
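
For reference, ShardingContext is just a value holder for what the constructor passes in. The sketch below is reconstructed from the constructor call above; the field names follow the arguments, and the real class generates the constructor and getters via Lombok.

public final class ShardingContext {
    private final ShardingRule shardingRule;      // routing rules, including master/slave rules
    private final DatabaseType databaseType;      // e.g. MySQL
    private final ExecutorEngine executorEngine;  // thread pool used to execute the routed statements
    private final boolean showSQL;                // whether to log the routed SQL
    ...
}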
3. Obtain a statement from the sharding-jdbc proxy connection
@RequiredArgsConstructor
public final class ShardingConnection extends AbstractConnectionAdapter {
    ...
    @Override
    public PreparedStatement prepareStatement(final String sql) throws SQLException {
        return new ShardingPreparedStatement(this, sql);
    }
    
    @Override
    public PreparedStatement prepareStatement(final String sql, final int resultSetType, final int resultSetConcurrency) throws SQLException {
        return new ShardingPreparedStatement(this, sql, resultSetType, resultSetConcurrency);
    }
    
    @Override
    public PreparedStatement prepareStatement(final String sql, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) throws SQLException {
        return new ShardingPreparedStatement(this, sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }
    
    @Override
    public PreparedStatement prepareStatement(final String sql, final int autoGeneratedKeys) throws SQLException {
        return new ShardingPreparedStatement(this, sql, autoGeneratedKeys);
    }
    
    @Override
    public PreparedStatement prepareStatement(final String sql, final int[] columnIndexes) throws SQLException {
        return new ShardingPreparedStatement(this, sql, Statement.RETURN_GENERATED_KEYS);
    }
    
    @Override
    public PreparedStatement prepareStatement(final String sql, final String[] columnNames) throws SQLException {
        return new ShardingPreparedStatement(this, sql, Statement.RETURN_GENERATED_KEYS);
    }
    
    @Override
    public Statement createStatement() throws SQLException {
        return new ShardingStatement(this);
    }
    
    @Override
    public Statement createStatement(final int resultSetType, final int resultSetConcurrency) throws SQLException {
        return new ShardingStatement(this, resultSetType, resultSetConcurrency);
    }
    
    @Override
    public Statement createStatement(final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) throws SQLException {
        return new ShardingStatement(this, resultSetType, resultSetConcurrency, resultSetHoldability);
    }
    ...
}

This is another similar wrapper class. Note that the first constructor argument is always this, i.e. these statement wrappers all hold a reference to the ShardingConnection, because ShardingConnection is the object that knows how to obtain the real database connections. Let's pick ShardingPreparedStatement as an example.
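
A simplified sketch of that relationship (the field names are illustrative, not the exact source):

public final class ShardingPreparedStatement extends AbstractShardingPreparedStatementAdapter {
    // The "this" passed in from ShardingConnection.prepareStatement(...); used later to obtain
    // the real connection for whichever data source the SQL is routed to.
    private final ShardingConnection connection;
    // The logical SQL; routing and creation of the real PreparedStatement happen at execute time.
    private final String sql;
    ...
}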

4. Execute the SQL. Read/write splitting applies to queries, so here we look at how ShardingPreparedStatement handles a query
public final class ShardingPreparedStatement extends AbstractShardingPreparedStatementAdapter {
    ...
    @Override
    public ResultSet executeQuery() throws SQLException {
        ResultSet result;
        try {
            Collection<PreparedStatementUnit> preparedStatementUnits = route();
            List<ResultSet> resultSets = new PreparedStatementExecutor(
                    getConnection().getShardingContext().getExecutorEngine(), routeResult.getSqlStatement().getType(), preparedStatementUnits, getParameters()).executeQuery();
            result = new ShardingResultSet(resultSets, new MergeEngine(resultSets, (SelectStatement) routeResult.getSqlStatement()).merge(), this);
        } finally {
            clearBatch();
        }
        currentResultSet = result;
        return result;
    }
    ...
}

The first call is to route(), which performs the routing; both sharding and master/slave splitting hook in here to route to the concrete target database, returning SQLExecutionUnit execution units. Then, based on the SQL type (already determined during routing), it checks whether the statement is DDL or not, and for non-DDL statements calls generatePreparedStatement to create the real PreparedStatement, which is stored in a PreparedStatementUnit.

public final class SQLExecutionUnit {
    
    private final String dataSource;//target data source name
    
    private final String sql;//SQL to execute
}


    private Collection<PreparedStatementUnit> route() throws SQLException {
        Collection<PreparedStatementUnit> result = new LinkedList<>();
        routeResult = routingEngine.route(getParameters());
        for (SQLExecutionUnit each : routeResult.getExecutionUnits()) {
            SQLType sqlType = routeResult.getSqlStatement().getType();
            Collection<PreparedStatement> preparedStatements;
            if (SQLType.DDL == sqlType) {
                preparedStatements = generatePreparedStatementForDDL(each);
            } else {
                preparedStatements = Collections.singletonList(generatePreparedStatement(each));
            }
            routedStatements.addAll(preparedStatements);
            for (PreparedStatement preparedStatement : preparedStatements) {
                replaySetParameter(preparedStatement);
                result.add(new PreparedStatementUnit(each, preparedStatement));
            }
        }
        return result;
    }

    private PreparedStatement generatePreparedStatement(final SQLExecutionUnit sqlExecutionUnit) throws SQLException {
        Connection connection = getConnection().getConnection(sqlExecutionUnit.getDataSource(), routeResult.getSqlStatement().getType());
        return returnGeneratedKeys ? connection.prepareStatement(sqlExecutionUnit.getSql(), Statement.RETURN_GENERATED_KEYS)
                : connection.prepareStatement(sqlExecutionUnit.getSql(), resultSetType, resultSetConcurrency, resultSetHoldability);
    }

In generatePreparedStatement we can see the call to getConnection(), i.e. the ShardingConnection reference, followed by a call to ShardingConnection.getConnection with the data source name and the SQL type, which returns the real database connection.

Next comes the key to read/write splitting. Back in the ShardingConnection class, look at the getConnection method:

    public Connection getConnection(final String dataSourceName, final SQLType sqlType) throws SQLException {
        if (getCachedConnections().containsKey(dataSourceName)) {
            return getCachedConnections().get(dataSourceName);
        }
        DataSource dataSource = shardingContext.getShardingRule().getDataSourceMap().get(dataSourceName);
        Preconditions.checkState(null != dataSource, "Missing the rule of %s in DataSourceRule", dataSourceName);
        String realDataSourceName;
        if (dataSource instanceof MasterSlaveDataSource) {
            NamedDataSource namedDataSource = ((MasterSlaveDataSource) dataSource).getDataSource(sqlType);
            realDataSourceName = namedDataSource.getName();
            if (getCachedConnections().containsKey(realDataSourceName)) {
                return getCachedConnections().get(realDataSourceName);
            }
            dataSource = namedDataSource.getDataSource();
        } else {
            realDataSourceName = dataSourceName;
        }
        Connection result = dataSource.getConnection();
        getCachedConnections().put(realDataSourceName, result);
        replayMethodsInvocation(result);
        return result;
    }
  • If the cache already contains a connection for this data source name, return it
  • Otherwise look up the real data source in the context's data source map using the routing result
  • If it is a master/slave data source, pick either the master or a slave according to the SQL type
  • Open a connection from the selected data source (or reuse the cached one if it exists)
    This yields the connection to the actual database the SQL will run against.

The key step here is choosing the master or a slave database based on the SQL type:

    public NamedDataSource getDataSource(final SQLType sqlType) {
        if (isMasterRoute(sqlType)) {
            DML_FLAG.set(true);
            return new NamedDataSource(masterSlaveRule.getMasterDataSourceName(), masterSlaveRule.getMasterDataSource());
        }
        String selectedSourceName = masterSlaveRule.getStrategy().getDataSource(masterSlaveRule.getName(), 
                masterSlaveRule.getMasterDataSourceName(), new ArrayList<>(masterSlaveRule.getSlaveDataSourceMap().keySet()));
        DataSource selectedSource = selectedSourceName.equals(masterSlaveRule.getMasterDataSourceName())
                ? masterSlaveRule.getMasterDataSource() : masterSlaveRule.getSlaveDataSourceMap().get(selectedSourceName);
        Preconditions.checkNotNull(selectedSource, "");
        return new NamedDataSource(selectedSourceName, selectedSource);
    }
    private boolean isMasterRoute(final SQLType sqlType) {
        return SQLType.DQL != sqlType || DML_FLAG.get() || HintManagerHolder.isMasterRouteOnly();
    }

As the code shows, a statement is routed to the master when:
(1) the SQL type is not DQL;
(2) the current thread's DML_FLAG (a ThreadLocal) is set to true;
(3) master-only routing is forced through a hint (see the sketch right after this list).
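
For case (3), a query can be forced onto the master through the hint API. A minimal sketch, assuming the sharding-jdbc 2.x HintManager (getInstance / setMasterRouteOnly, AutoCloseable); dataSource is the sharding DataSource from earlier and the t_order table is illustrative:

try (HintManager hintManager = HintManager.getInstance()) {
    hintManager.setMasterRouteOnly(); // HintManagerHolder.isMasterRouteOnly() now returns true for this thread
    try (Connection conn = dataSource.getConnection();
         PreparedStatement ps = conn.prepareStatement("SELECT status FROM t_order WHERE order_id = ?")) {
        ps.setLong(1, 1L);
        ps.executeQuery(); // DQL, but routed to the master because of the hint
    }
}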

public enum SQLType {
    
    /**
     * Data Query Language.
     * 
     * <p>Such as {@code SELECT}.</p>
     */
    DQL,
    
    /**
     * Data Manipulation Language.
     *
     * <p>Such as {@code INSERT}, {@code UPDATE}, {@code DELETE}.</p>
     */
    DML,
    
    /**
     * Data Definition Language.
     *
     * <p>Such as {@code CREATE}, {@code ALTER}, {@code DROP}, {@code TRUNCATE}.</p>
     */
    DDL,
    
    /**
     * Transaction Control Language.
     *
     * <p>Such as {@code SET}, {@code COMMIT}, {@code ROLLBACK}, {@code SAVEPOINT}, {@code BEGIN}.</p>
     */
    TCL
}

SELECT statements: DQL
INSERT, UPDATE, DELETE statements: DML
CREATE, ALTER, DROP, TRUNCATE statements: DDL
SET, COMMIT, ROLLBACK, SAVEPOINT, BEGIN statements: TCL

So if isMasterRoute returns true, the statement goes straight to the master:

        if (isMasterRoute(sqlType)) {
            DML_FLAG.set(true);//mark this thread as having routed to the master
            return new NamedDataSource(masterSlaveRule.getMasterDataSourceName(), masterSlaveRule.getMasterDataSource());
        }
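
Note the side effect of DML_FLAG.set(true) in the snippet above: once a statement on a thread has been routed to the master, later SELECTs on the same thread are also routed to the master, because isMasterRoute then sees DML_FLAG as true. A small illustration (dataSource is the sharding DataSource from earlier, t_order is a hypothetical table):

try (Connection conn = dataSource.getConnection()) {
    // DML: not DQL, so it goes to the master and sets DML_FLAG for this thread.
    conn.prepareStatement("INSERT INTO t_order (order_id, status) VALUES (1, 'INIT')").executeUpdate();
    // DQL, but DML_FLAG is already true, so this also goes to the master instead of a slave.
    conn.prepareStatement("SELECT status FROM t_order WHERE order_id = 1").executeQuery();
}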

Otherwise a slave is picked according to the load-balancing algorithm:

        //pick a slave according to the load-balancing strategy
        String selectedSourceName = masterSlaveRule.getStrategy().getDataSource(masterSlaveRule.getName(), 
                masterSlaveRule.getMasterDataSourceName(), new ArrayList<>(masterSlaveRule.getSlaveDataSourceMap().keySet()));
        //there is a master check here, but in practice the master is essentially never selected; it is unclear why this check is needed??
        DataSource selectedSource = selectedSourceName.equals(masterSlaveRule.getMasterDataSourceName())
                ? masterSlaveRule.getMasterDataSource() : masterSlaveRule.getSlaveDataSourceMap().get(selectedSourceName);
        Preconditions.checkNotNull(selectedSource, "");
        //wrap the selected slave in a NamedDataSource and return it
        return new NamedDataSource(selectedSourceName, selectedSource);

At the moment sharding-jdbc supports only two load-balancing algorithms, round-robin and random (ROUND_ROBIN, round-robin, is the default):

public interface MasterSlaveLoadBalanceAlgorithm {
    String getDataSource(String name, String masterDataSourceName, List<String> slaveDataSourceNames);
}

@Getter
public final class MasterSlaveRule {
    ...
    public MasterSlaveRule(final String name, final String masterDataSourceName, 
                           final DataSource masterDataSource, final Map<String, DataSource> slaveDataSourceMap, final MasterSlaveLoadBalanceAlgorithm strategy) {
        ...
        this.strategy = null == strategy ? MasterSlaveLoadBalanceAlgorithmType.getDefaultAlgorithmType().getAlgorithm() : strategy;
    }
}
@RequiredArgsConstructor
@Getter
public enum MasterSlaveLoadBalanceAlgorithmType {
    
    ROUND_ROBIN(new RoundRobinMasterSlaveLoadBalanceAlgorithm()),
    RANDOM(new RandomMasterSlaveLoadBalanceAlgorithm());
    
    private final MasterSlaveLoadBalanceAlgorithm algorithm;
    ...
    public static MasterSlaveLoadBalanceAlgorithmType getDefaultAlgorithmType() {
        return ROUND_ROBIN;
    }
}

public final class RoundRobinMasterSlaveLoadBalanceAlgorithm implements MasterSlaveLoadBalanceAlgorithm {
    
    private static final ConcurrentHashMap<String, AtomicInteger> COUNT_MAP = new ConcurrentHashMap<>();
    
    @Override
    public String getDataSource(final String name, final String masterDataSourceName, final List<String> slaveDataSourceNames) {
        AtomicInteger count = COUNT_MAP.containsKey(name) ? COUNT_MAP.get(name) : new AtomicInteger(0);
        COUNT_MAP.putIfAbsent(name, count);
        count.compareAndSet(slaveDataSourceNames.size(), 0);
        return slaveDataSourceNames.get(count.getAndIncrement() % slaveDataSourceNames.size());
    }
}

public final class RandomMasterSlaveLoadBalanceAlgorithm implements MasterSlaveLoadBalanceAlgorithm {
    
    @Override
    public String getDataSource(final String name, final String masterDataSourceName, final List<String> slaveDataSourceNames) {
        return slaveDataSourceNames.get(new Random().nextInt(slaveDataSourceNames.size()));
    }
}
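
A quick illustration of the round-robin behaviour (the logical name and slave names are hypothetical). The counter is kept per logical data source name, so each name cycles through its own slaves:

MasterSlaveLoadBalanceAlgorithm algorithm = new RoundRobinMasterSlaveLoadBalanceAlgorithm();
List<String> slaves = Arrays.asList("ds_slave_0", "ds_slave_1");
algorithm.getDataSource("ds_ms", "ds_master", slaves); // returns "ds_slave_0"
algorithm.getDataSource("ds_ms", "ds_master", slaves); // returns "ds_slave_1"
algorithm.getDataSource("ds_ms", "ds_master", slaves); // returns "ds_slave_0" again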
5. PreparedStatementExecutor executes the real query with the statement units obtained above
public final class PreparedStatementExecutor {
    public List<ResultSet> executeQuery() throws SQLException {
        return executorEngine.executePreparedStatement(sqlType, preparedStatementUnits, parameters, new ExecuteCallback<ResultSet>() {
            
            @Override
            public ResultSet execute(final BaseStatementUnit baseStatementUnit) throws Exception {
                return ((PreparedStatement) baseStatementUnit.getStatement()).executeQuery();
            }
        });
    }
}

Finally, the results are returned.
