DeepSeek参考SumCount编写的duckdb调用gmp库高精度聚合函数 - l1t1/note GitHub Wiki

DeepSeek 参考 SumCount 编写的 DuckDB 调用 GMP 库的高精度聚合函数

// Declared in the header file.
// MPZSum groups the registration entry point for the "mpz_sum" aggregate.
struct MPZSum{
// Registers the "mpz_sum" aggregate in `catalog`, using the client context owned by `conn`.
static void RegisterFunction(duckdb::Connection &conn, duckdb::Catalog &catalog);
};
// Implemented in the .cpp file.
// Aggregate state for mpz_sum: holds the running total.
struct MPZSumState {
    mpz_t sum;  // running total stored as a GMP arbitrary-precision integer
};

// State lifecycle operations for mpz_sum. Initialize/Destroy are handed to
// duckdb::AggregateFunction::StateInitialize / StateDestroy when the
// aggregate is constructed.
struct MPZSumFunction {
    // Always reports true.
    // NOTE(review): presumably consulted by DuckDB's aggregate helpers to skip
    // NULL inputs - confirm against the DuckDB version in use.
    static bool IgnoreNull() {
        return true;
    }

    // Puts `state.sum` into a valid GMP state; mpz_init sets the value to 0.
    template <typename STATE>
    static void Initialize(STATE &state) {
        mpz_init(state.sum);
    }

    // Releases the GMP storage owned by `state.sum`.
    template <typename STATE>
    static void Destroy(STATE &state, duckdb::AggregateInputData &aggr_input_data) {
        mpz_clear(state.sum);
    }
};

// Copies the contents of `input` into a std::string after checking that it is
// a non-empty run of ASCII decimal digits.
//
// @param input  DuckDB string value to validate and copy.
// @return       The digits of `input` as an owned std::string.
// @throws std::runtime_error if `input` is empty or contains any character
//         outside '0'..'9' (signs, whitespace and separators are rejected).
static std::string GetNumericString(const duckdb::string_t& input) {
    // Use string_t's own accessors rather than reinterpreting the object as the
    // C-API duckdb_string_t: the accessors already hide the inlined-vs-pointer
    // representation, so no layout or inline-length constant is assumed here.
    const char* data = input.GetData();
    const auto length = input.GetSize();

    // An empty string is not a number; reject it here with a precise message
    // instead of letting the later GMP conversion fail.
    if (length == 0) {
        throw std::runtime_error("Empty string is not a number");
    }

    // Validate: digits only.
    const char* const end = data + length;
    for (const char* p = data; p != end; ++p) {
        if (*p < '0' || *p > '9') {
            throw std::runtime_error("Invalid character in number");
        }
    }

    return std::string(data, length);
}

// Update callback for mpz_sum: for every valid (non-NULL) input row, parses the
// VARCHAR value as a base-10 integer and adds it to that row's state.
//
// @param inputs        Argument vectors; only inputs[0] (VARCHAR) is read.
// @param input_count   Number of argument vectors (unused).
// @param state_vector  Vector of MPZSumState* pointers, one per row.
// @param count         Number of rows to process.
// @throws std::runtime_error ("Error processing number: ...") when a value is
//         not a valid decimal digit string.
static void MPZSumUpdate(duckdb::Vector inputs[], duckdb::AggregateInputData &, idx_t input_count, duckdb::Vector &state_vector, idx_t count) {
    // Scratch GMP integer reused for every row. The destructor guarantees
    // mpz_clear runs on both the normal and the exceptional path, so no
    // per-row init/clear pair is needed.
    struct ScratchInteger {
        mpz_t value;
        ScratchInteger() { mpz_init(value); }
        ~ScratchInteger() { mpz_clear(value); }
        ScratchInteger(const ScratchInteger &) = delete;
        ScratchInteger &operator=(const ScratchInteger &) = delete;
    };

    auto &input = inputs[0];
    duckdb::UnifiedVectorFormat sdata;
    state_vector.ToUnifiedFormat(count, sdata);
    duckdb::UnifiedVectorFormat input_data;
    input.ToUnifiedFormat(count, input_data);

    // Both data pointers are loop-invariant, so fetch them once.
    auto states = duckdb::UnifiedVectorFormat::GetData<MPZSumState *>(sdata);
    auto values = duckdb::UnifiedVectorFormat::GetData<duckdb::string_t>(input_data);

    ScratchInteger addend;
    for (idx_t i = 0; i < count; i++) {
        const auto input_idx = input_data.sel->get_index(i);
        if (!input_data.validity.RowIsValid(input_idx)) {
            continue;  // NULL input: leave the state untouched
        }
        auto &state = *states[sdata.sel->get_index(i)];

        try {
            // Validate and copy the digits; the copy also provides the
            // NUL-terminated buffer that mpz_set_str requires.
            const std::string num_str = GetNumericString(values[input_idx]);

            if (mpz_set_str(addend.value, num_str.c_str(), 10) != 0) {
                throw std::runtime_error("Failed to convert string to GMP number");
            }
            mpz_add(state.sum, state.sum, addend.value);
        } catch (const std::exception &e) {
            // Re-throw with a uniform prefix so callers see one error family.
            throw std::runtime_error("Error processing number: " + std::string(e.what()));
        }
    }
}

// Finalize callback for mpz_sum: converts each state's total to its base-10
// text form and stores it in `result` at position `i + offset`.
//
// @param state_vector  Vector of MPZSumState* pointers, one per row.
// @param result        Output vector; written through FlatVector::GetData as string_t.
// @param count         Number of states to finalize.
// @param offset        Index of the first output row in `result`.
// @throws std::runtime_error if GMP fails to produce the string.
//
// NOTE(review): a state that never received a row yields whatever mpz_init
// left in it (zero), not SQL NULL - confirm this is the intended semantics.
static void MPZSumFinalize(duckdb::Vector &state_vector, duckdb::AggregateInputData &, duckdb::Vector &result, idx_t count, idx_t offset) {
    duckdb::UnifiedVectorFormat sdata;
    state_vector.ToUnifiedFormat(count, sdata);
    auto states = duckdb::UnifiedVectorFormat::GetData<MPZSumState *>(sdata);

    // mpz_get_str(nullptr, ...) allocates with GMP's *current* allocation
    // function, so the buffer must be released with GMP's matching free
    // function; plain free() is only correct while the default allocator is
    // installed.
    void (*gmp_release)(void *, size_t) = nullptr;
    mp_get_memory_functions(nullptr, nullptr, &gmp_release);

    for (idx_t i = 0; i < count; i++) {
        const auto rid = i + offset;
        auto &state = *states[sdata.sel->get_index(i)];

        char *str = mpz_get_str(nullptr, 10, state.sum);
        if (!str) {
            throw std::runtime_error("Failed to convert GMP number to string");
        }
        // GMP documents the allocated block as exactly strlen(str) + 1 bytes.
        const size_t length = std::char_traits<char>::length(str);

        try {
            // AddString copies the characters into storage owned by `result`,
            // so the GMP buffer can be released right afterwards.
            duckdb::FlatVector::GetData<duckdb::string_t>(result)[rid] =
                duckdb::StringVector::AddString(result, str, length);
        } catch (...) {
            gmp_release(str, length + 1);
            throw;
        }
        gmp_release(str, length + 1);
    }
}
// Combine callback for mpz_sum: adds each source state's total into the
// target state at the same row position.
//
// @param state_vector  Vector of source MPZSumState* pointers.
// @param combined      Vector of target MPZSumState* pointers, read via FlatVector::GetData.
// @param count         Number of state pairs to merge.
static void MPZSumCombine(duckdb::Vector &state_vector, duckdb::Vector &combined, duckdb::AggregateInputData &, idx_t count) {
    duckdb::UnifiedVectorFormat source_format;
    state_vector.ToUnifiedFormat(count, source_format);

    auto targets = duckdb::FlatVector::GetData<MPZSumState *>(combined);
    auto sources = (MPZSumState **)source_format.data;

    for (idx_t row = 0; row < count; ++row) {
        MPZSumState *source = sources[source_format.sel->get_index(row)];
        MPZSumState *target = targets[row];
        // target += source
        mpz_add(target->sum, target->sum, source->sum);
    }
}

// Bind callback for mpz_sum: sets the function's return type to VARCHAR and
// supplies no bind data (returns a null pointer).
duckdb::unique_ptr<duckdb::FunctionData> MPZSumBind(
    duckdb::ClientContext &context,
    duckdb::AggregateFunction &function,
    duckdb::vector<duckdb::unique_ptr<duckdb::Expression>> &arguments) {
    const duckdb::LogicalType varchar_type = duckdb::LogicalType::VARCHAR;
    function.return_type = varchar_type;  // the total is returned as text
    return nullptr;
}

// Builds the duckdb::AggregateFunction descriptor for "mpz_sum":
// one VARCHAR argument, VARCHAR result, with MPZSumState as the state and the
// MPZSum* callbacks wired in.
duckdb::AggregateFunction GetMPZSumFunction() {
    duckdb::AggregateFunction mpz_sum_function(
        "mpz_sum",                                                                  // function name
        {duckdb::LogicalType::VARCHAR},                                             // argument types
        duckdb::LogicalType::VARCHAR,                                               // return type
        duckdb::AggregateFunction::StateSize<MPZSumState>,                          // state size
        duckdb::AggregateFunction::StateInitialize<MPZSumState, MPZSumFunction>,    // initialize
        MPZSumUpdate,                                                               // update
        MPZSumCombine,                                                              // combine
        MPZSumFinalize,                                                             // finalize
        nullptr,                                                                    // simple update (none)
        MPZSumBind,                                                                 // bind
        duckdb::AggregateFunction::StateDestroy<MPZSumState, MPZSumFunction>        // destroy
    );
    return mpz_sum_function;
}

// Wraps the mpz_sum aggregate in a function set and creates it in `catalog`,
// using the client context owned by `conn`.
void MPZSum::RegisterFunction(duckdb::Connection &conn, duckdb::Catalog &catalog) {
    duckdb::AggregateFunctionSet function_set("mpz_sum");
    function_set.AddFunction(GetMPZSumFunction());

    duckdb::CreateAggregateFunctionInfo create_info(function_set);
    auto &client_context = *conn.context;
    catalog.CreateFunction(client_context, create_info);
}
// Register the function from main.
// Example only: the `...` below stands for the rest of the program.
int main() {
    DuckDB db("/par/test.db");
    Connection con(db);

    // Register the custom aggregate inside a transaction on the system catalog.
    con.BeginTransaction();
    auto &catalog = Catalog::GetSystemCatalog(*con.context);    
    MPZSum::RegisterFunction(con, catalog);
    con.Commit();
    ...
}    

SQL调用方法

========================================
duckdb> select mpz_sum(value) from numbers where id=1;
mpz_sum("value")
VARCHAR
[ Rows: 1]
12345678901234567890


duckdb> select mpz_sum(value) from numbers;
mpz_sum("value")
VARCHAR
[ Rows: 1]
12345678901234569357


duckdb>  select mpz_sum(value)over(order by id),id  from numbers;
mpz_sum("value") OVER (ORDER BY id)     id
VARCHAR INTEGER
[ Rows: 4]
12345678901234567890    1
12345678901234568013    2
12345678901234568469    3
12345678901234569357    4
⚠️ **GitHub.com Fallback** ⚠️