引导DeepSeek编写的调用GMP库在DuckDB中进行高精度运算的程序 - l1t1/note GitHub Wiki

一开始想让他帮助实现一个具有以上功能的DuckDB插件,但编译出的二进制动态库文件在加载时总是报错“Invalid Input Error: Failed to load './libgmp_extension.so', The file is not a DuckDB extension. The metadata at the end of the file is invalid”,无法解决。

所以,退而求其次,让他“编写一个c++程序,它调用duckdb动态库和gmp库,读取指定duckdb数据库文件,把指定文本文件中的SQL语句执行结果输出”。

调用duckdb动态库执行文本文件中的SQL语句功能很快实现了,但调用gmp库总是报错“Error: {"exception_type":"Invalid Input","exception_message":"Invalid unicode (byte sequence mismatch) detected in value construction"}”。

调试中发现,调用gmp库计算类似3*4的小整数结果是对的,大整数则出错。查阅文档看到“String values are stored as a duckdb_string_t. This is a special struct that stores the string inline (if it is short, i.e., <= 12 bytes) or a pointer to the string data if it is longer than 12 bytes.”,再查询duckdb.hpp源代码,找到这个结构体的定义如下。

//! The internal representation of a VARCHAR (string_t). If the VARCHAR does not
//! exceed 12 characters, then we inline it. Otherwise, we inline a prefix for faster
//! string comparisons and store a pointer to the remaining characters. This is a non-
//! owning structure, i.e., it does not have to be freed.
// Both layouts below share the same leading 32-bit length field; which one
// applies depends on whether the string fits in the 12 inline bytes.
typedef struct {
	union {
		// Layout for strings that exceed 12 characters.
		struct {
			uint32_t length; // string length
			char prefix[4];  // inlined prefix, kept for faster string comparisons
			char *ptr;       // pointer to the string data; non-owning, never freed through this struct
		} pointer;
		// Layout for strings of at most 12 characters.
		struct {
			uint32_t length;  // string length
			char inlined[12]; // the string bytes themselves, stored in place
		} inlined;
	} value;
} duckdb_string_t;

又在文档中找到一个示例程序代码片段

// Print one string entry of a result vector. The length-limited "%.*s" format
// is used in both branches, so the data is never assumed to be NUL-terminated.
duckdb_string_t str = vector_data[row];
if (duckdb_string_is_inlined(str)) {
// short string: the bytes live inside the struct itself
printf("%.*s\n", str.value.inlined.length, str.value.inlined.inlined);
} else {
// long string: the bytes live behind the stored pointer
printf("%.*s\n", str.value.pointer.length, str.value.pointer.ptr);
}

把它们提供给DeepSeek,终于编写出了能正确读写duckdb_string_t的程序,然后让他把这两个string_t读写函数代入原来的程序,并支持依次执行文件中的多条SQL语句,结果如下。

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
#include <vector>

#include <duckdb.hpp>
#include <gmp.h>

// Type aliases to shorten the code below
using duckdb::string_t;
using duckdb::LogicalType;
using duckdb::vector;



class GMPSQLExecutor {
public:
    GMPSQLExecutor(const std::string& db_path, const std::string& sql_file)
        : db_path_(db_path), sql_file_(sql_file) {
        // 初始化DuckDB连接(简化配置)
        db_ = std::make_unique<duckdb::DuckDB>(db_path);
        conn_ = std::make_unique<duckdb::Connection>(*db_);
        
        // 初始化GMP
        mpz_init(gmp_value_);
        
        // 注册自定义函数
        registerGMPFunctions();
        
    }

    ~GMPSQLExecutor() {
        mpz_clear(gmp_value_);
    }
    static void print_string_bytes(const string_t& s) {
    std::cout << "String bytes: ";
    const auto& str = s.GetString();
    for(char c : str) {
        printf("%02X ", (unsigned char)c);
    }
    std::cout << std::endl;
}


// 将std::string存入string_t
static duckdb::string_t StoreString(const std::string& input) {
    duckdb::string_t result;
    auto* raw = reinterpret_cast<duckdb_string_t*>(&result);
    
    if (input.size() <= 12) {
        // 内联存储
        raw->value.inlined.length = input.size();
        memcpy(raw->value.inlined.inlined, input.data(), input.size());
    } else {
        // 指针存储(注意:实际使用时需要确保内存生命周期)
        raw->value.pointer.length = input.size();
        memcpy(raw->value.pointer.prefix, input.data(), 4);
        raw->value.pointer.ptr = (char*)malloc(input.size());
        memcpy(raw->value.pointer.ptr, input.data(), input.size());
    }
    return result;
}

// 从string_t读取std::string
static std::string ReadString(const duckdb::string_t& input) {
    const auto* raw = reinterpret_cast<const duckdb_string_t*>(&input);
    
    if (raw->value.inlined.length <= 12) {
        return std::string(
            raw->value.inlined.inlined,
            raw->value.inlined.length
        );
    } else {
        return std::string(
            raw->value.pointer.ptr,
            raw->value.pointer.length
        );
    }
}


// 从string_t安全读取数字字符串
static std::string GetNumericString(const duckdb::string_t& input) {
    const auto* raw = reinterpret_cast<const duckdb_string_t*>(&input);
    
    // 获取字符串指针和长度
    const char* data;
    uint32_t length;
    
    if (raw->value.inlined.length <= 12) {
        data = raw->value.inlined.inlined;
        length = raw->value.inlined.length;
    } else {
        data = raw->value.pointer.ptr;
        length = raw->value.pointer.length;
    }
    
    // 验证纯数字
    for (uint32_t i = 0; i < length; i++) {
        if (data[i] < '0' || data[i] > '9') {
            throw std::runtime_error("Invalid character in number");
        }
    }
    
    return std::string(data, length);
}

static string_t mpz_add_impl(string_t a, string_t b) {
    // 安全获取数字字符串
    std::string a_str, b_str;
    try {
        a_str = GetNumericString(a);
        b_str = GetNumericString(b);
    } catch (const std::exception& e) {
        throw std::runtime_error(std::string("Invalid input: ") + e.what());
    }

    // GMP加法计算
    mpz_t num1, num2, result;
    mpz_init(num1);
    mpz_init(num2);
    mpz_init(result);
    
    if (mpz_set_str(num1, a_str.c_str(), 10) != 0 ||
        mpz_set_str(num2, b_str.c_str(), 10) != 0) {
        mpz_clear(num1);
        mpz_clear(num2);
        mpz_clear(result);
        throw std::runtime_error("Failed to parse number");
    }
    
    mpz_add(result, num1, num2);  // 唯一不同点:使用mpz_add而非mpz_mul
    
    // 构造返回的string_t
    char* res_str = mpz_get_str(nullptr, 10, result);
    duckdb::string_t ret = StoreString(std::string(res_str));
    free(res_str);
    
    mpz_clear(num1);
    mpz_clear(num2);
    mpz_clear(result);
    
    return ret;
}
static string_t mpz_mul_impl(string_t a, string_t b) {

    // 安全获取数字字符串
    std::string a_str, b_str;
    try {
        a_str = GetNumericString(a);
        b_str = GetNumericString(b);
    } catch (const std::exception& e) {
        throw std::runtime_error(std::string("Invalid input: ") + e.what());
    }    

    // GMP计算
    mpz_t num1, num2, result;
    mpz_init(num1);
    mpz_init(num2);
    mpz_init(result);
    
    
    if(mpz_set_str(num1, a_str.c_str(), 10) == -1 ||
       mpz_set_str(num2, b_str.c_str(), 10) == -1) {
        mpz_clear(num1);
        mpz_clear(num2);
        mpz_clear(result);
        throw std::runtime_error("Invalid number format");
    }
    
    mpz_mul(result, num1, num2);
    
    // 构造返回字符串(自动处理存储方式)
    char* res_str = mpz_get_str(nullptr, 10, result);
  
    duckdb::string_t ret = StoreString(std::string(res_str));  // 关键修改点
    free(res_str);
    mpz_clear(num1);
    mpz_clear(num2);
    mpz_clear(result);
    
    return ret;
}


    void registerGMPFunctions() {
        // 使用简化版CreateScalarFunction(自动类型推导)
        conn_->CreateScalarFunction("mpz_add", &GMPSQLExecutor::mpz_add_impl);
        conn_->CreateScalarFunction("mpz_mul", &GMPSQLExecutor::mpz_mul_impl);
        
        /* 或者使用完整版(如果需要指定类型):
        conn_->CreateScalarFunction<string_t, string_t, string_t>(
            "mpz_add",
            {LogicalType::VARCHAR, LogicalType::VARCHAR},
            LogicalType::VARCHAR,
            &GMPSQLExecutor::mpz_add_impl
        );
        */
    }
    
bool execute() {
    // 读取SQL文件
    std::ifstream sql_stream(sql_file_);
    if (!sql_stream.is_open()) {
        std::cerr << "Error opening SQL file: " << sql_file_ << std::endl;
        return false;
    }

    // 读取完整脚本
    std::string script(
        (std::istreambuf_iterator<char>(sql_stream)),
        std::istreambuf_iterator<char>()
    );
    
    bool all_success = true;
    size_t start_pos = 0;
    size_t semicolon_pos;

    // 按分号拆分语句
    while ((semicolon_pos = script.find(';', start_pos)) != std::string::npos) {
        // 提取单条语句
        std::string single_sql = script.substr(
            start_pos, 
            semicolon_pos - start_pos + 1
        );
        
        // 修剪空白字符
        single_sql.erase(0, single_sql.find_first_not_of(" \n\r\t"));
        single_sql.erase(single_sql.find_last_not_of(" \n\r\t") + 1);

        if (!single_sql.empty()) {
            std::cout << "\nExecuting: " 
                      << single_sql.substr(0, 50) 
                      << (single_sql.length() > 50 ? "..." : "")
                      << std::endl;

            try {
                // 执行单条语句
                auto result = conn_->Query(single_sql);
                if (result->HasError()) {
                    std::cerr << "Error in SQL: " << result->GetError() << std::endl;
                    all_success = false;
                    // 继续执行下一条语句
                } else if (result->RowCount() > 0) {
                    result->Print();
                }
            } catch (const std::exception& e) {
                std::cerr << "Exception: " << e.what() << std::endl;
                all_success = false;
            }
        }

        start_pos = semicolon_pos + 1;
    }

    // 检查是否有未处理的尾部语句
    if (start_pos < script.length()) {
        std::string remaining = script.substr(start_pos);
        remaining.erase(0, remaining.find_first_not_of(" \n\r\t"));
        if (!remaining.empty()) {
            std::cerr << "Warning: Unterminated SQL statement (missing semicolon): " 
                      << remaining.substr(0, 50) << "..." << std::endl;
        }
    }

    return all_success;
}

private:
    std::string db_path_;
    std::string sql_file_;
    std::unique_ptr<duckdb::DuckDB> db_;
    std::unique_ptr<duckdb::Connection> conn_;
    mpz_t gmp_value_;
};

int main(int argc, char** argv) {
    if (argc != 3) {
        std::cerr << "Usage: " << argv[0] << " <database_file> <sql_file>" << std::endl;
        return 1;
    }

    try {
        GMPSQLExecutor executor(argv[1], argv[2]);
        if (!executor.execute()) {
            return 1;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}

编译命令行如下,sql_gmp/sql_gmp.cpp是以上源代码路径。/par/duck/build/src是保存官网下载的duckdb动态库文件libduckdb.so的目录。/par/duck/src/include是保存duckdb源代码的目录,其中包含duckdb.hpp。gmp库事先用apt install libgmp-dev命令安装。

# Point the link step (LIBRARY_PATH) and the runtime loader (LD_LIBRARY_PATH)
# at the directory that holds libduckdb.so.
export LIBRARY_PATH=/par/duck/build/src
export LD_LIBRARY_PATH=/par/duck/build/src
# Compile as C++17, take duckdb.hpp from the DuckDB source tree, and link
# against libduckdb and libgmp.
g++ -std=c++17 -o sql_gmp/sql_gmp sql_gmp/sql_gmp.cpp -lduckdb -lgmp -I /par/duck/src/include

用duckdb CLI建立test.db数据库,并保存退出。

duckdb test.db "CREATE TABLE numbers(id INTEGER, value VARCHAR); INSERT INTO numbers VALUES (1, '12345678901234567890');"

编写query.sql如下

-- Optional: uncomment to create the table from this script on a fresh database
-- CREATE TABLE numbers(id INTEGER, value VARCHAR);
INSERT INTO numbers VALUES (2, '123'), (3, '456');
-- Optional sanity checks, disabled by default
-- SELECT mpz_add(value, '100') FROM numbers;
-- select * from numbers;
-- value * 2 + 2324 for every row, computed on the decimal strings by mpz_mul and mpz_add
SELECT value, mpz_add(mpz_mul(value, '2') , '2324')FROM numbers;

然后用如下命令行调用,就能计算字符串中保存的大整数的和与积。也可以去掉注释,在query.sql中建立新表,插入其他数据。

sql_gmp/sql_gmp test.db query.sql

在query.sql中输入如下代码就能计算大数的阶乘。

-- 100! via a recursive CTE: column f carries the running product as a
-- decimal string and is multiplied by the next n with mpz_mul on each step
WITH RECURSIVE t AS (
    SELECT 1 n, 1::varchar f
    UNION ALL
    SELECT n+1, mpz_mul(f, (n+1)::varchar)
    FROM t
    WHERE n<100
)
SELECT n||'!='||f
FROM t
WHERE n=100;

将上述二进制文件随同duckdb动态库文件复制到其他装有gmp库的同CPU架构机器就能正常执行。

⚠️ **GitHub.com Fallback** ⚠️