第一个问题很简单:只要在g_typemap中添加一行{LogicalTypeId::LIST, pgwire::Oid::Varchar},再在duckdb_handler的switch中、case LogicalTypeId::VARCHAR:之前添加case LogicalTypeId::LIST:即可(这样LIST类型会按字符串形式输出)。
第二个问题比较棘手。PIVOT语句在后台实际上会执行多条语句,因此无法用prepare+execute的方式执行,只能用query执行。我一开始试图检查语句是否以PIVOT开头,以便分别采用两种方法执行,结果行不通,因为query返回的结果集没有GetNames等成员函数,无法获得列名等信息。把错误信息Invalid Input Error: Cannot prepare multiple statements at once拿到网上搜索,在duckdb_ui插件的issue中找到了https://github.com/duckdb/duckdb-ui/issues/57。回帖采用的方法是:用ExtractStatements方法把实际要执行的语句提取出来,先执行前n-1条,最后用prepare执行最后1条。我试了一下,可行。完整修改后的duckdb_pgwire_extension.cpp如下:
#include "duckdb/common/types.hpp"
#include <unordered_map>
#define DUCKDB_EXTENSION_MAIN
#include <duckdb_pgwire_extension.hpp>
#include <duckdb/common/exception.hpp>
#include <duckdb/common/string_util.hpp>
#include <duckdb/function/scalar_function.hpp>
//#include <duckdb/main/extension_util.hpp>
#include <duckdb/main/extension/extension_loader.hpp>
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
#include <pgwire/exception.hpp>
#include <pgwire/log.hpp>
#include <pgwire/server.hpp>
#include <pgwire/types.hpp>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
namespace duckdb {
// Set to true by start_server once the pgwire server has been launched in this
// process; guards against starting a second server.
static std::atomic<bool> g_started{false};

// DuckDB logical type -> PostgreSQL type OID reported to the client for each
// result field. Result columns whose type is not listed here are skipped by
// duckdb_handler (they appear neither in the field list nor in the rows).
// NOTE(review): BOOLEAN has a case in duckdb_handler's value switch but no
// entry here, so boolean columns are currently dropped; add one once the
// matching pgwire::Oid enumerator is confirmed.
static std::unordered_map<LogicalTypeId, pgwire::Oid> g_typemap = {
    {LogicalTypeId::FLOAT, pgwire::Oid::Float4},
    {LogicalTypeId::DOUBLE, pgwire::Oid::Float8},
    // {LogicalTypeId::TINYINT, pgwire::Oid::Char},
    {LogicalTypeId::SMALLINT, pgwire::Oid::Int2},
    {LogicalTypeId::INTEGER, pgwire::Oid::Int4},
    {LogicalTypeId::BIGINT, pgwire::Oid::Int8},
    {LogicalTypeId::HUGEINT, pgwire::Oid::Int8},
    // The following types are written to the client as strings.
    {LogicalTypeId::LIST, pgwire::Oid::Varchar},
    {LogicalTypeId::VARCHAR, pgwire::Oid::Varchar},
    {LogicalTypeId::DATE, pgwire::Oid::Date},
    {LogicalTypeId::TIME, pgwire::Oid::Time},
    {LogicalTypeId::TIMESTAMP, pgwire::Oid::Timestamp},
    // Must be keyed on TIMESTAMP_TZ: a repeated TIMESTAMP key in an
    // unordered_map initializer list is silently ignored, which left
    // TIMESTAMP WITH TIME ZONE columns unmapped and missing from results.
    {LogicalTypeId::TIMESTAMP_TZ, pgwire::Oid::TimestampTz},
};
// Builds the pgwire parse callback that serves queries against `db`.
//
// For each incoming query text the returned callback opens a fresh Connection
// and splits the text into statements. Some statements (e.g. PIVOT) expand into
// several statements, and DuckDB refuses to prepare more than one at a time
// ("Cannot prepare multiple statements at once", see
// https://github.com/duckdb/duckdb-ui/issues/57). So every statement except the
// last is executed right away, and only the last one is prepared. Its result
// columns become the statement's fields, and it is executed by the returned
// statement's handler.
//
// Every failure (parse, execution of a leading statement, prepare, execute) is
// rethrown as pgwire::SqlException with SqlState::DataException. Result columns
// whose DuckDB type has no entry in g_typemap are omitted from both the field
// list and the written rows.
static pgwire::ParseHandler duckdb_handler(DatabaseInstance &db) {
    return [&db](std::string const &query) mutable {
        Connection conn(db);
        pgwire::PreparedStatement stmt;
        std::unique_ptr<PreparedStatement> prepared;
        std::optional<pgwire::SqlException> error;
        std::vector<std::string> column_names;
        std::vector<LogicalType> column_types;
        std::size_t column_total = 0;
        try {
            // ExtractStatements can throw (e.g. on a syntax error); doing it
            // inside the try reports that to the client like any other error.
            auto statements = conn.ExtractStatements(query);
            if (statements.empty()) {
                // Without this guard `size() - 1` below wraps around and the
                // vector is indexed out of bounds.
                throw std::runtime_error("query does not contain any statement");
            }
            auto const last = statements.size() - 1;
            // Run the leading statements now and stop at the first failure,
            // instead of silently continuing to the final statement.
            for (std::size_t i = 0; i < last; ++i) {
                auto pending = conn.PendingQuery(std::move(statements[i]), true);
                if (pending->HasError()) {
                    throw std::runtime_error(pending->GetError());
                }
                auto leading_result = pending->Execute();
                if (!leading_result) {
                    throw std::runtime_error(
                        "failed to execute query with unknown error");
                }
                if (leading_result->HasError()) {
                    throw std::runtime_error(leading_result->GetError());
                }
            }
            prepared = conn.Prepare(std::move(statements[last]));
            if (!prepared) {
                throw std::runtime_error(
                    "failed prepare query with unknown error");
            }
            if (prepared->HasError()) {
                throw std::runtime_error(prepared->GetError());
            }
            column_names = prepared->GetNames();
            column_types = prepared->GetTypes();
            column_total = prepared->ColumnCount();
        } catch (std::exception &e) {
            error =
                pgwire::SqlException{e.what(), pgwire::SqlState::DataException};
        }
        // rethrow error outside the handler so it reaches the client as-is
        if (error) {
            throw *error;
        }

        stmt.fields.reserve(column_total);
        for (std::size_t i = 0; i < column_total; i++) {
            auto &name = column_names[i];
            auto &type = column_types[i];
            auto it = g_typemap.find(type.id());
            if (it == g_typemap.end()) {
                continue; // unsupported type: column is not exposed
            }
            auto oid = it->second;
            // can't use emplace_back for POD struct in C++17
            stmt.fields.push_back({name, oid});
        }

        stmt.handler = [column_total, p = std::move(prepared)](
                           pgwire::Writer &writer,
                           pgwire::Values const &parameters) mutable {
            std::unique_ptr<QueryResult> result;
            std::optional<pgwire::SqlException> error;
            try {
                result = p->Execute();
                if (!result) {
                    throw std::runtime_error(
                        "failed to execute query with unknown error");
                }
                if (result->HasError()) {
                    throw std::runtime_error(result->GetError());
                }
            } catch (std::exception &e) {
                error = pgwire::SqlException{e.what(),
                                             pgwire::SqlState::DataException};
            }
            if (error) {
                throw *error;
            }

            auto &column_types = p->GetTypes();
            for (auto &chunk : *result) {
                auto row = writer.add_row();
                for (std::size_t i = 0; i < column_total; i++) {
                    auto &type = column_types[i];
                    // Must skip exactly the columns skipped when building
                    // stmt.fields, so values line up with the field list.
                    if (g_typemap.find(type.id()) == g_typemap.end()) {
                        continue;
                    }
                    auto value = chunk.iterator.chunk->GetValue(i, chunk.row);
                    if (value.IsNull()) {
                        row.write_null();
                        continue;
                    }
                    switch (type.id()) {
                    case LogicalTypeId::FLOAT:
                        row.write_float4(chunk.GetValue<float>(i));
                        break;
                    case LogicalTypeId::DOUBLE:
                        row.write_float8(chunk.GetValue<double>(i));
                        break;
                    case LogicalTypeId::SMALLINT:
                        row.write_int2(chunk.GetValue<int16_t>(i));
                        break;
                    case LogicalTypeId::INTEGER:
                        row.write_int4(chunk.GetValue<int32_t>(i));
                        break;
                    case LogicalTypeId::BIGINT:
                        row.write_int8(chunk.GetValue<int64_t>(i));
                        break;
                    case LogicalTypeId::HUGEINT: {
                        auto const huge = chunk.GetValue<hugeint_t>(i);
                        // Assemble the 128 bits in unsigned arithmetic:
                        // left-shifting a negative signed value is undefined
                        // behaviour before C++20.
                        auto const bits =
                            (static_cast<__uint128_t>(
                                 static_cast<uint64_t>(huge.upper))
                             << 64) |
                            static_cast<__uint128_t>(huge.lower);
                        row.write_int16(static_cast<__int128_t>(bits));
                        break;
                    }
                    case LogicalTypeId::BOOLEAN:
                        row.write_bool(chunk.GetValue<bool>(i));
                        break;
                    case LogicalTypeId::LIST:
                    case LogicalTypeId::VARCHAR:
                    case LogicalTypeId::DATE:
                    case LogicalTypeId::TIME:
                    case LogicalTypeId::TIMESTAMP:
                    case LogicalTypeId::TIMESTAMP_TZ:
                        row.write_string(chunk.GetValue<std::string>(i));
                        break;
                    default:
                        break;
                    }
                }
            }
        };
        return stmt;
    };
}
// Runs the pgwire server for `db` on TCP port 15432 (IPv4, all interfaces),
// logging to "duckdb_pgwire.log". Only the first call in the process starts a
// server; later calls return immediately. Intended to run on its own thread.
static void start_server(DatabaseInstance &db) {
    using namespace asio;
    // exchange() makes check-and-set a single atomic step, so two concurrent
    // callers cannot both get past this guard.
    if (g_started.exchange(true)) {
        return;
    }
    try {
        io_context io_context;
        ip::tcp::endpoint endpoint(ip::tcp::v4(), 15432);
        pgwire::log::initialize(io_context, "duckdb_pgwire.log");
        pgwire::Server server(
            io_context, endpoint,
            [&db](pgwire::Session &sess) mutable { return duckdb_handler(db); });
        server.start();
    } catch (std::exception &e) {
        // This function is the body of a detached thread: an exception that
        // escapes it calls std::terminate and kills the host process (e.g. when
        // the port is already in use). Report it and allow a later retry.
        std::cerr << "duckdb_pgwire: pgwire server failed: " << e.what()
                  << '\n';
        g_started = false;
    }
}
// Scalar function `pg_is_in_recovery() -> BOOLEAN`; always returns false.
// Presumably provided for PostgreSQL clients that call it when connecting.
inline void PgIsInRecovery(DataChunk &args, ExpressionState &state,
                           Vector &result) {
    // A constant vector makes the value valid for every row of the chunk.
    // Setting only row 0 of a flat vector leaves the remaining rows
    // uninitialised whenever the call is evaluated per row rather than
    // constant-folded.
    result.Reference(Value::BOOLEAN(false));
}
// Scalar function `duckdb_pgwire(VARCHAR) -> VARCHAR`: maps each input string
// s to "DuckdbPgwire " + s + " 🐥"; NULL inputs are handled by UnaryExecutor.
inline void DuckdbPgwireScalarFun(DataChunk &args, ExpressionState &state, Vector &result) {
    auto &input_column = args.data[0];
    auto decorate = [&result](string_t who) {
        std::string text = "DuckdbPgwire ";
        text += who.GetString();
        text += " 🐥";
        // Copy the string into the result vector's own storage.
        return StringVector::AddString(result, text);
    };
    UnaryExecutor::Execute<string_t, string_t>(input_column, result, args.size(), decorate);
}
/*
static void LoadInternal(DatabaseInstance &instance) {
// Register a scalar function
auto pg_is_in_recovery_scalar_function = ScalarFunction(
"pg_is_in_recovery", {}, LogicalType::BOOLEAN, PgIsInRecovery);
ExtensionUtil::RegisterFunction(instance,
pg_is_in_recovery_scalar_function);
auto duckdb_pgwire_scalar_function = ScalarFunction("duckdb_pgwire", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DuckdbPgwireScalarFun);
ExtensionUtil::RegisterFunction(instance, duckdb_pgwire_scalar_function);
std::thread([&instance]() mutable { start_server(instance); }).detach();
}
void DuckdbPgwireExtension::Load(DuckDB &db) { LoadInternal(*db.instance); }
std::string DuckdbPgwireExtension::Name() { return "duckdb_pgwire"; }
*/
// Registers the extension's scalar functions (pg_is_in_recovery,
// duckdb_pgwire) and starts the pgwire server on a detached background thread.
static void LoadInternal(ExtensionLoader &loader) {
    // pg_is_in_recovery() -> BOOLEAN
    auto recovery_fn = ScalarFunction("pg_is_in_recovery", {}, LogicalType::BOOLEAN, PgIsInRecovery);
    loader.RegisterFunction(recovery_fn);

    // duckdb_pgwire(VARCHAR) -> VARCHAR
    auto greeting_fn = ScalarFunction("duckdb_pgwire", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DuckdbPgwireScalarFun);
    loader.RegisterFunction(greeting_fn);

    // NOTE(review): the detached thread keeps a reference to the
    // DatabaseInstance; confirm the instance outlives the server thread.
    auto &db_instance = loader.GetDatabaseInstance();
    std::thread server_thread([&db_instance]() { start_server(db_instance); });
    server_thread.detach();
}
// Extension hook: performs all registration through LoadInternal.
void DuckdbPgwireExtension::Load(ExtensionLoader &loader) {
    LoadInternal(loader);
}

// Name of this extension.
std::string DuckdbPgwireExtension::Name() {
    return "duckdb_pgwire";
}

// Version string supplied at build time, or "" when none was defined.
// NOTE(review): DuckDB's extension build is expected to define
// EXT_VERSION_<UPPERCASE EXTENSION NAME>, i.e. EXT_VERSION_DUCKDB_PGWIRE here.
// EXT_VERSION_QUACK is the extension template's macro name; checking only it
// always yielded "", so it is kept as a fallback only. Confirm against the
// build's CMake.
std::string DuckdbPgwireExtension::Version() const {
#if defined(EXT_VERSION_DUCKDB_PGWIRE)
    return EXT_VERSION_DUCKDB_PGWIRE;
#elif defined(EXT_VERSION_QUACK)
    return EXT_VERSION_QUACK;
#else
    return "";
#endif
}
} // namespace duckdb
// C-linkage entry points for loading this extension.
extern "C" {
/*
 Previous entry points based on DuckDB::LoadExtension, superseded by the
 DUCKDB_CPP_EXTENSION_ENTRY definition below. Kept for reference only; they
 are commented out and not compiled.
DUCKDB_EXTENSION_API void duckdb_pgwire_init(duckdb::DatabaseInstance &db) {
duckdb::DuckDB db_wrapper(db);
db_wrapper.LoadExtension<duckdb::DuckdbPgwireExtension>();
}
DUCKDB_EXTENSION_API const char *duckdb_pgwire_version() {
return duckdb::DuckDB::LibraryVersion();
}
*/
// Entry point for extension `duckdb_pgwire`, declared through DuckDB's
// DUCKDB_CPP_EXTENSION_ENTRY macro; it forwards the provided loader to
// LoadInternal, which registers the functions and starts the server thread.
DUCKDB_CPP_EXTENSION_ENTRY(duckdb_pgwire, loader) {
duckdb::LoadInternal(loader);
}
}
#ifndef DUCKDB_EXTENSION_MAIN
#error DUCKDB_EXTENSION_MAIN not defined
#endif
执行结果如下:
D CREATE TABLE cities ( country VARCHAR, name VARCHAR, year INTEGER, population INTEGER);
D INSERT INTO cities VALUES
('NL', 'Amsterdam', 2000, 1005),
('NL', 'Amsterdam', 2010, 1065),
('NL', 'Amsterdam', 2020, 1158),
('US', 'Seattle', 2000, 564),
('US', 'Seattle', 2010, 608),
('US', 'Seattle', 2020, 738),
('US', 'New York City', 2000, 8015),
('US', 'New York City', 2010, 8175),
('US', 'New York City', 2020, 8772);
main=> select 1::int128 a;
a
---
1
(1 row)
main=> PIVOT cities ON year USING first(population);
country | name | 2000 | 2010 | 2020
---------+---------------+------+------+------
NL | Amsterdam | 1005 | 1065 | 1158
US | New York City | 8015 | 8175 | 8772
US | Seattle | 564 | 608 | 738
(3 rows)
main=> select list(i)from range(4)t(i);
list(i)
--------------
[0, 1, 2, 3]
(1 row)
main=>


被折叠的 条评论
为什么被折叠?



