DeepSeek参考SumCount编写的duckdb调用gmp库高精度聚合函数 - l1t1/note GitHub Wiki
DeepSeek参考SumCount编写的duckdb调用gmp库高精度聚合函数
// Declared in the header file.
// Holder type whose only member is the static entry point that registers the
// mpz_sum aggregate function.
struct MPZSum{
// Registers the mpz_sum aggregate; takes the connection and the catalog to register into.
static void RegisterFunction(duckdb::Connection &conn, duckdb::Catalog &catalog);
};
// Written in the .cpp file.
// Aggregate state for mpz_sum: holds the running total.
struct MPZSumState {
mpz_t sum; // running total, stored as a GMP arbitrary-precision integer
};
// Lifecycle hooks for the mpz_sum aggregate state. The hooks are templated on
// the state type and only require that it exposes an `mpz_t sum` member.
struct MPZSumFunction {
    // Always reports true.
    // NOTE(review): presumably consumed by DuckDB's aggregate templates to
    // decide whether NULL rows are skipped -- confirm against the DuckDB API.
    static bool IgnoreNull() {
        return true;
    }

    // Prepares the GMP integer inside a freshly allocated state so that it can
    // be used as an accumulator.
    template <class STATE_T>
    static void Initialize(STATE_T &state) {
        mpz_init(state.sum);
    }

    // Releases the GMP integer owned by the state. The aggregate input data
    // argument is accepted but not used.
    template <class STATE_T>
    static void Destroy(STATE_T &state, duckdb::AggregateInputData &aggr_input_data) {
        mpz_clear(state.sum);
    }
};
// Copies the bytes of a DuckDB string value into a std::string after checking
// that every byte is an ASCII decimal digit ('0'..'9').
//
// The payload is read through string_t's own accessors (GetData/GetSize)
// rather than by casting to the C-API struct, so the function does not depend
// on the in-memory layout or on the inline-length threshold.
//
// Returns: the digits as an owning, null-terminated std::string (the input
//          string_t is not guaranteed to be null-terminated).
// Throws:  std::runtime_error("Invalid character in number") if any byte is
//          not a decimal digit. Signs, whitespace and separators are rejected.
// Note:    an empty input contains no invalid byte, so it is returned as an
//          empty string; the caller decides how to treat it.
static std::string GetNumericString(const duckdb::string_t& input) {
    const char *data = input.GetData();
    const auto length = input.GetSize();

    // Validate: digits only.
    const char *end = data + length;
    for (const char *cursor = data; cursor != end; ++cursor) {
        if (*cursor < '0' || *cursor > '9') {
            throw std::runtime_error("Invalid character in number");
        }
    }
    return std::string(data, length);
}
// Update callback of the mpz_sum aggregate: for each of the `count` rows, parses
// the VARCHAR input as a base-10 integer and adds it to the sum held by the
// state that the row maps to.
//
// inputs       - argument vectors; only inputs[0] (the VARCHAR value) is read.
// input_count  - number of argument vectors; not consulted.
// state_vector - vector of MPZSumState pointers, one per row.
// count        - number of rows to process.
//
// Rows whose input is NULL (invalid in the validity mask) are skipped.
// Throws std::runtime_error("Error processing number: ...") when a value is
// not a pure digit string or GMP refuses to parse it.
static void MPZSumUpdate(duckdb::Vector inputs[], duckdb::AggregateInputData &, idx_t input_count, duckdb::Vector &state_vector, idx_t count) {
    (void)input_count;

    duckdb::UnifiedVectorFormat sdata;
    state_vector.ToUnifiedFormat(count, sdata);
    duckdb::UnifiedVectorFormat input_data;
    inputs[0].ToUnifiedFormat(count, input_data);

    auto states = reinterpret_cast<MPZSumState **>(sdata.data);
    // Loop-invariant: the typed view of the input payload.
    const auto values = duckdb::UnifiedVectorFormat::GetData<duckdb::string_t>(input_data);

    // One scratch integer reused for every row; the destructor releases it on
    // both the normal and the exceptional exit path.
    struct ScratchInteger {
        mpz_t value;
        ScratchInteger() { mpz_init(value); }
        ~ScratchInteger() { mpz_clear(value); }
        ScratchInteger(const ScratchInteger &) = delete;
        ScratchInteger &operator=(const ScratchInteger &) = delete;
    };
    ScratchInteger parsed;

    for (idx_t i = 0; i < count; i++) {
        const auto input_idx = input_data.sel->get_index(i);
        if (!input_data.validity.RowIsValid(input_idx)) {
            continue; // NULL input contributes nothing
        }
        auto &state = *states[sdata.sel->get_index(i)];
        try {
            // mpz_set_str needs a null-terminated buffer, hence the owned copy.
            const std::string num_str = GetNumericString(values[input_idx]);
            if (mpz_set_str(parsed.value, num_str.c_str(), 10) != 0) {
                throw std::runtime_error("Failed to convert string to GMP number");
            }
        } catch (const std::exception &e) {
            // Re-raise with a prefix identifying the failing stage.
            throw std::runtime_error("Error processing number: " + std::string(e.what()));
        }
        mpz_add(state.sum, state.sum, parsed.value);
    }
}
// Finalize callback of the mpz_sum aggregate: renders each state's sum as a
// base-10 string and stores it in `result` at position `i + offset`.
//
// state_vector - vector of MPZSumState pointers, one per row.
// result       - VARCHAR output vector; the string bytes are copied into it
//                via StringVector::AddString.
// count        - number of states to finalize.
// offset       - index of the first output row inside `result`.
//
// Throws std::runtime_error if GMP returns no string.
static void MPZSumFinalize(duckdb::Vector &state_vector, duckdb::AggregateInputData &, duckdb::Vector &result, idx_t count, idx_t offset) {
    duckdb::UnifiedVectorFormat sdata;
    state_vector.ToUnifiedFormat(count, sdata);
    auto states = reinterpret_cast<MPZSumState **>(sdata.data);
    auto result_data = duckdb::FlatVector::GetData<duckdb::string_t>(result);

    // mpz_get_str(nullptr, ...) allocates through GMP's currently installed
    // allocation function, so the buffer must be released with the matching
    // GMP free function rather than with free().
    void (*gmp_free)(void *, size_t) = nullptr;
    mp_get_memory_functions(nullptr, nullptr, &gmp_free);

    // Owns one GMP-allocated string and releases it on every exit path.
    struct GmpStringGuard {
        char *text;
        size_t bytes;
        void (*release)(void *, size_t);
        GmpStringGuard(char *text_p, size_t bytes_p, void (*release_p)(void *, size_t))
            : text(text_p), bytes(bytes_p), release(release_p) {
        }
        ~GmpStringGuard() {
            release(text, bytes);
        }
        GmpStringGuard(const GmpStringGuard &) = delete;
        GmpStringGuard &operator=(const GmpStringGuard &) = delete;
    };

    for (idx_t i = 0; i < count; i++) {
        const auto rid = i + offset;
        auto &state = *states[sdata.sel->get_index(i)];

        char *str = mpz_get_str(nullptr, 10, state.sum);
        if (!str) {
            throw std::runtime_error("Failed to convert GMP number to string");
        }
        // GMP documents the allocated block as exactly strlen(str) + 1 bytes.
        const size_t length = std::char_traits<char>::length(str);
        GmpStringGuard guard(str, length + 1, gmp_free);

        // AddString copies the bytes into the result vector, so the GMP buffer
        // can be released as soon as this iteration ends.
        result_data[rid] = duckdb::StringVector::AddString(result, str, length);
    }
}
// Combine callback of the mpz_sum aggregate: adds the sum of every source state
// in `state_vector` into the state at the same row of `combined`.
// Source states are read only; target states are updated in place.
static void MPZSumCombine(duckdb::Vector &state_vector, duckdb::Vector &combined, duckdb::AggregateInputData &, idx_t count) {
    duckdb::UnifiedVectorFormat source_format;
    state_vector.ToUnifiedFormat(count, source_format);

    auto sources = (MPZSumState **)source_format.data;
    auto targets = duckdb::FlatVector::GetData<MPZSumState *>(combined);

    for (idx_t row = 0; row < count; ++row) {
        // Sources go through the selection vector; targets are indexed directly.
        const auto source_idx = source_format.sel->get_index(row);
        MPZSumState *target = targets[row];
        mpz_add(target->sum, target->sum, sources[source_idx]->sum); // merge the GMP integers
    }
}
// Bind callback for mpz_sum: sets the function's return type to VARCHAR and
// returns no bind data (nullptr). The context and arguments are not inspected.
duckdb::unique_ptr<duckdb::FunctionData> MPZSumBind(duckdb::ClientContext &context, duckdb::AggregateFunction &function, duckdb::vector<duckdb::unique_ptr<duckdb::Expression>> &arguments) {
function.return_type = duckdb::LogicalType::VARCHAR; // result is returned as a string
return nullptr;
}
// Builds the AggregateFunction descriptor for mpz_sum: one VARCHAR argument,
// VARCHAR result, wired to the callbacks defined in this file.
duckdb::AggregateFunction GetMPZSumFunction() {
    using State = MPZSumState;
    using Hooks = MPZSumFunction;
    using duckdb::AggregateFunction;
    using duckdb::LogicalType;

    AggregateFunction mpz_sum_function(
        "mpz_sum",                                        // function name
        {LogicalType::VARCHAR},                           // argument types: a single string
        LogicalType::VARCHAR,                             // return type: string
        AggregateFunction::StateSize<State>,              // state size
        AggregateFunction::StateInitialize<State, Hooks>, // state initialization
        MPZSumUpdate,                                     // update
        MPZSumCombine,                                    // combine
        MPZSumFinalize,                                   // finalize
        nullptr,                                          // no simple-update callback
        MPZSumBind,                                       // bind
        AggregateFunction::StateDestroy<State, Hooks>     // state destruction
    );
    return mpz_sum_function;
}
// Registers the mpz_sum aggregate in `catalog`: wraps the function in a
// single-entry function set named "mpz_sum" and creates it through the client
// context of `conn`.
void MPZSum::RegisterFunction(duckdb::Connection &conn, duckdb::Catalog &catalog) {
    duckdb::AggregateFunctionSet function_set("mpz_sum");
    function_set.AddFunction(GetMPZSumFunction());

    duckdb::CreateAggregateFunctionInfo create_info(function_set);
    auto &client_context = *conn.context;
    catalog.CreateFunction(client_context, create_info);
}
// Registration from the main function (illustrative sketch).
// NOTE(review): DuckDB, Connection and Catalog are used unqualified here, so
// this presumably relies on a `using namespace duckdb;` not shown -- confirm.
int main() {
DuckDB db("/par/test.db");
Connection con(db);
// Register the custom function inside an explicit transaction.
con.BeginTransaction();
auto &catalog = Catalog::GetSystemCatalog(*con.context);
MPZSum::RegisterFunction(con, catalog);
con.Commit();
...
}
SQL调用方法
========================================
duckdb> select mpz_sum(value) from numbers where id=1;
mpz_sum("value")
VARCHAR
[ Rows: 1]
12345678901234567890
duckdb> select mpz_sum(value) from numbers;
mpz_sum("value")
VARCHAR
[ Rows: 1]
12345678901234569357
duckdb> select mpz_sum(value)over(order by id),id from numbers;
mpz_sum("value") OVER (ORDER BY id) id
VARCHAR INTEGER
[ Rows: 4]
12345678901234567890 1
12345678901234568013 2
12345678901234568469 3
12345678901234569357 4