Guiding DeepSeek to write a program that calls the GMP library for high-precision arithmetic in DuckDB - l1t1/note GitHub Wiki
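At first I wanted it to help me build a DuckDB extension with this functionality, but the compiled shared library always failed to load with "Invalid Input Error: Failed to load './libgmp_extension.so', The file is not a DuckDB extension. The metadata at the end of the file is invalid". I could not get past this.
So I settled for the next best thing and asked it to "write a C++ program that calls the duckdb shared library and the gmp library, opens a specified duckdb database file, and prints the results of executing the SQL statements in a specified text file".
Executing the SQL statements from the text file through the duckdb shared library worked quickly, but the GMP calls kept failing with "Error: {"exception_type":"Invalid Input","exception_message":"Invalid unicode (byte sequence mismatch) detected in value construction"}".
While debugging I found that GMP calls on small numbers such as 3*4 gave correct results, while large numbers failed. The documentation says: "String values are stored as a duckdb_string_t. This is a special struct that stores the string inline (if it is short, i.e., <= 12 bytes) or a pointer to the string data if it is longer than 12 bytes." Searching the duckdb.hpp source, I then found the definition of this struct: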
//! The internal representation of a VARCHAR (string_t). If the VARCHAR does not
//! exceed 12 characters, then we inline it. Otherwise, we inline a prefix for faster
//! string comparisons and store a pointer to the remaining characters. This is a non-
//! owning structure, i.e., it does not have to be freed.
typedef struct {
union {
struct {
uint32_t length;
char prefix[4];
char *ptr;
} pointer;
struct {
uint32_t length;
char inlined[12];
} inlined;
} value;
} duckdb_string_t;
I also found an example code snippet in the documentation:
duckdb_string_t str = vector_data[row];
if (duckdb_string_is_inlined(str)) {
// use inlined string
printf("%.*s\n", str.value.inlined.length, str.value.inlined.inlined);
} else {
// follow string pointer
printf("%.*s\n", str.value.pointer.length, str.value.pointer.ptr);
}
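After I gave both of these to DeepSeek, it finally produced correct code for reading and writing duckdb_string_t. I then had it plug the two string_t read/write functions into the original program and add support for running multiple SQL statements from the file in order. The result: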
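#include <cstdio>     // printf
#include <cstdlib>    // malloc, free
#include <cstring>    // memcpy, memset
#include <fstream>
#include <iostream>
#include <iterator>   // std::istreambuf_iterator
#include <memory>
#include <stdexcept>  // std::runtime_error
#include <string>
#include <vector>
#include <duckdb.hpp>
#include <gmp.h>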
// Type aliases to shorten the code
using duckdb::string_t;
using duckdb::LogicalType;
using duckdb::vector;
class GMPSQLExecutor {
public:
GMPSQLExecutor(const std::string& db_path, const std::string& sql_file)
: db_path_(db_path), sql_file_(sql_file) {
// Open the DuckDB database and a connection (default configuration)
db_ = std::make_unique<duckdb::DuckDB>(db_path);
conn_ = std::make_unique<duckdb::Connection>(*db_);
// Initialize a GMP integer (the member gmp_value_ is not used by the SQL functions below)
mpz_init(gmp_value_);
// Register the custom SQL functions
registerGMPFunctions();
}
~GMPSQLExecutor() {
mpz_clear(gmp_value_);
}
// Debugging helper: print the bytes of a string_t in hex (not called anywhere below)
static void print_string_bytes(const string_t& s) {
std::cout << "String bytes: ";
const auto& str = s.GetString();
for(char c : str) {
printf("%02X ", (unsigned char)c);
}
std::cout << std::endl;
}
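// Store a std::string into a string_t
static duckdb::string_t StoreString(const std::string& input) {
duckdb::string_t result;
auto* raw = reinterpret_cast<duckdb_string_t*>(&result);
// Zero all 16 bytes first: DuckDB's own string_t constructor zero-fills unused inline bytes, and comparisons may read them
memset(raw, 0, sizeof(*raw));
if (input.size() <= 12) {
// Inline storage
raw->value.inlined.length = input.size();
memcpy(raw->value.inlined.inlined, input.data(), input.size());
} else {
// Pointer storage: the malloc'ed copy below is never freed (it leaks), but it must stay valid for as long as DuckDB reads the result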
raw->value.pointer.length = input.size();
memcpy(raw->value.pointer.prefix, input.data(), 4);
raw->value.pointer.ptr = (char*)malloc(input.size());
memcpy(raw->value.pointer.ptr, input.data(), input.size());
}
return result;
}
// Read a std::string back from a string_t
static std::string ReadString(const duckdb::string_t& input) {
const auto* raw = reinterpret_cast<const duckdb_string_t*>(&input);
if (raw->value.inlined.length <= 12) {
return std::string(
raw->value.inlined.inlined,
raw->value.inlined.length
);
} else {
return std::string(
raw->value.pointer.ptr,
raw->value.pointer.length
);
}
}
// Read a string_t as a std::string, rejecting anything that is not a run of digits
static std::string GetNumericString(const duckdb::string_t& input) {
const auto* raw = reinterpret_cast<const duckdb_string_t*>(&input);
// Get the data pointer and length (inline or pointer layout)
const char* data;
uint32_t length;
if (raw->value.inlined.length <= 12) {
data = raw->value.inlined.inlined;
length = raw->value.inlined.length;
} else {
data = raw->value.pointer.ptr;
length = raw->value.pointer.length;
}
// Accept the digits 0-9 only (no sign, spaces or decimal point)
for (uint32_t i = 0; i < length; i++) {
if (data[i] < '0' || data[i] > '9') {
throw std::runtime_error("Invalid character in number");
}
}
return std::string(data, length);
}
static string_t mpz_add_impl(string_t a, string_t b) {
// Read both arguments as digit strings
std::string a_str, b_str;
try {
a_str = GetNumericString(a);
b_str = GetNumericString(b);
} catch (const std::exception& e) {
throw std::runtime_error(std::string("Invalid input: ") + e.what());
}
// GMP addition
mpz_t num1, num2, result;
mpz_init(num1);
mpz_init(num2);
mpz_init(result);
if (mpz_set_str(num1, a_str.c_str(), 10) != 0 ||
mpz_set_str(num2, b_str.c_str(), 10) != 0) {
mpz_clear(num1);
mpz_clear(num2);
mpz_clear(result);
throw std::runtime_error("Failed to parse number");
}
mpz_add(result, num1, num2); // the only difference from mpz_mul_impl: mpz_add instead of mpz_mul
// Build the returned string_t
char* res_str = mpz_get_str(nullptr, 10, result);
duckdb::string_t ret = StoreString(std::string(res_str));
free(res_str);
mpz_clear(num1);
mpz_clear(num2);
mpz_clear(result);
return ret;
}
static string_t mpz_mul_impl(string_t a, string_t b) {
// Read both arguments as digit strings
std::string a_str, b_str;
try {
a_str = GetNumericString(a);
b_str = GetNumericString(b);
} catch (const std::exception& e) {
throw std::runtime_error(std::string("Invalid input: ") + e.what());
}
// GMP multiplication
mpz_t num1, num2, result;
mpz_init(num1);
mpz_init(num2);
mpz_init(result);
if(mpz_set_str(num1, a_str.c_str(), 10) != 0 ||
mpz_set_str(num2, b_str.c_str(), 10) != 0) {
mpz_clear(num1);
mpz_clear(num2);
mpz_clear(result);
throw std::runtime_error("Invalid number format");
}
mpz_mul(result, num1, num2);
// Build the returned string (StoreString picks inline or pointer storage)
char* res_str = mpz_get_str(nullptr, 10, result);
duckdb::string_t ret = StoreString(std::string(res_str)); // key change: build the string_t with StoreString
free(res_str);
mpz_clear(num1);
mpz_clear(num2);
mpz_clear(result);
return ret;
}
void registerGMPFunctions() {
// Simple form of CreateScalarFunction: argument and return types are deduced from the function signature
conn_->CreateScalarFunction("mpz_add", &GMPSQLExecutor::mpz_add_impl);
conn_->CreateScalarFunction("mpz_mul", &GMPSQLExecutor::mpz_mul_impl);
/* Or use the full form if the types must be given explicitly:
conn_->CreateScalarFunction<string_t, string_t, string_t>(
"mpz_add",
{LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::VARCHAR,
&GMPSQLExecutor::mpz_add_impl
);
*/
}
bool execute() {
// Open the SQL file
std::ifstream sql_stream(sql_file_);
if (!sql_stream.is_open()) {
std::cerr << "Error opening SQL file: " << sql_file_ << std::endl;
return false;
}
// Read the whole script into memory
std::string script(
(std::istreambuf_iterator<char>(sql_stream)),
std::istreambuf_iterator<char>()
);
bool all_success = true;
size_t start_pos = 0;
size_t semicolon_pos;
// Split at each ';' (note: a ';' inside a string literal or a -- comment also splits)
while ((semicolon_pos = script.find(';', start_pos)) != std::string::npos) {
// Extract one statement, including its ';'
std::string single_sql = script.substr(
start_pos,
semicolon_pos - start_pos + 1
);
// Trim leading and trailing whitespace
single_sql.erase(0, single_sql.find_first_not_of(" \n\r\t"));
single_sql.erase(single_sql.find_last_not_of(" \n\r\t") + 1);
if (!single_sql.empty()) {
std::cout << "\nExecuting: "
<< single_sql.substr(0, 50)
<< (single_sql.length() > 50 ? "..." : "")
<< std::endl;
try {
// Execute the single statement
auto result = conn_->Query(single_sql);
if (result->HasError()) {
std::cerr << "Error in SQL: " << result->GetError() << std::endl;
all_success = false;
// Continue with the next statement
} else if (result->RowCount() > 0) {
result->Print();
}
} catch (const std::exception& e) {
std::cerr << "Exception: " << e.what() << std::endl;
all_success = false;
}
}
start_pos = semicolon_pos + 1;
}
// Report any trailing text that has no terminating ';'
if (start_pos < script.length()) {
std::string remaining = script.substr(start_pos);
remaining.erase(0, remaining.find_first_not_of(" \n\r\t"));
if (!remaining.empty()) {
std::cerr << "Warning: Unterminated SQL statement (missing semicolon): "
<< remaining.substr(0, 50) << "..." << std::endl;
}
}
return all_success;
}
private:
std::string db_path_;
std::string sql_file_;
std::unique_ptr<duckdb::DuckDB> db_;
std::unique_ptr<duckdb::Connection> conn_;
mpz_t gmp_value_;
};
int main(int argc, char** argv) {
if (argc != 3) {
std::cerr << "Usage: " << argv[0] << " <database_file> <sql_file>" << std::endl;
return 1;
}
try {
GMPSQLExecutor executor(argv[1], argv[2]);
if (!executor.execute()) {
return 1;
}
} catch (const std::exception& e) {
std::cerr << "Error: " << e.what() << std::endl;
return 1;
}
return 0;
}
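The compile commands are below. sql_gmp/sql_gmp.cpp is the path of the source code above, /par/duck/build/src is the directory holding libduckdb.so (the DuckDB shared library downloaded from the official website), and /par/duck/src/include is the include directory of the DuckDB source tree, which contains duckdb.hpp (the headers should come from the same DuckDB version as libduckdb.so). The GMP library was installed beforehand with apt install libgmp-dev.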
export LIBRARY_PATH=/par/duck/build/src
export LD_LIBRARY_PATH=/par/duck/build/src
g++ -std=c++17 -o sql_gmp/sql_gmp sql_gmp/sql_gmp.cpp -lduckdb -lgmp -I /par/duck/src/include
Create the test.db database with the duckdb CLI; the command runs the SQL and exits, which saves the data.
duckdb test.db "CREATE TABLE numbers(id INTEGER, value VARCHAR); INSERT INTO numbers VALUES (1, '12345678901234567890');"
Write query.sql as follows:
-- CREATE TABLE numbers(id INTEGER, value VARCHAR);
INSERT INTO numbers VALUES (2, '123'), (3, '456');
-- SELECT mpz_add(value, '100') FROM numbers;
-- select * from numbers;
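SELECT value, mpz_add(mpz_mul(value, '2'), '2324') FROM numbers;
Then run the following command to compute sums and products of big integers stored as strings. You can also uncomment the lines in query.sql to create a new table and insert other data (against a new database file, since test.db already contains the numbers table).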
sql_gmp/sql_gmp test.db query.sql
Putting the following into query.sql computes the factorial of a large number, here 100!:
with recursive t as (select 1 n, 1::varchar f
union all select n+1,mpz_mul(f, (n+1)::varchar) from t where n<100)
select n||'!='||f from t where n=100;
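The binary built above runs fine when copied, together with the duckdb shared library, to another machine with the same CPU architecture that has the GMP library installed.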