给duckdb_pgwire插件添加输出列表和PIVOT语句功能

第一个很简单,只要在g_typemap中添加一行{LogicalTypeId::LIST, pgwire::Oid::Varchar},再在duckdb_handler中在case LogicalTypeId::VARCHAR:前添加case LogicalTypeId::LIST:即可。

第二个比较棘手。PIVOT语句在后台实际上会执行多个语句,因此无法用prepare+execute的方式执行,只能用query执行。一开始我试图通过判断语句是否以PIVOT开头来分别采用这两种方法执行,结果行不通,因为query查询的结果集没有GetNames等成员函数,无法获得相应的信息。把错误信息“Invalid Input Error: Cannot prepare multiple statements at once”拿到网上查找,结果在duckdb_ui插件的issue中找到了 https://github.com/duckdb/duckdb-ui/issues/57 。回帖采取的方法是用ExtractStatements方法把实际执行的语句提取出来,先执行前n-1个,最后用prepare执行最后1个。我试了,可以。完整修改后的duckdb_pgwire_extension.cpp如下:

#include "duckdb/common/types.hpp"
#include <unordered_map>
#define DUCKDB_EXTENSION_MAIN

#include <duckdb_pgwire_extension.hpp>

#include <duckdb/common/exception.hpp>
#include <duckdb/common/string_util.hpp>
#include <duckdb/function/scalar_function.hpp>
//#include <duckdb/main/extension_util.hpp>
#include <duckdb/main/extension/extension_loader.hpp>
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>

#include <pgwire/exception.hpp>
#include <pgwire/log.hpp>
#include <pgwire/server.hpp>
#include <pgwire/types.hpp>

#include <atomic>
#include <cstdio>
#include <optional>
#include <stdexcept>
#include <thread>

namespace duckdb {

/// Set once the pgwire server has been started; start_server() consults it so
/// that only one server is launched. Initialized explicitly because a
/// default-constructed std::atomic carries no defined value before C++20.
static std::atomic<bool> g_started{false};

/// Maps every DuckDB logical type the handler can serialize to the PostgreSQL
/// type OID advertised for that column. duckdb_handler skips columns whose
/// type is not listed here, both when describing fields and when writing rows.
///
/// NOTE(review): duckdb_handler has a write path for LogicalTypeId::BOOLEAN,
/// but there is no entry for it here, so boolean columns are currently
/// skipped. Add one once the matching pgwire::Oid enumerator is confirmed.
static std::unordered_map<LogicalTypeId, pgwire::Oid> g_typemap = {
    {LogicalTypeId::FLOAT, pgwire::Oid::Float4},
    {LogicalTypeId::DOUBLE, pgwire::Oid::Float8},
    // {LogicalTypeId::TINYINT, pgwire::Oid::Char},
    {LogicalTypeId::SMALLINT, pgwire::Oid::Int2},
    {LogicalTypeId::INTEGER, pgwire::Oid::Int4},
    {LogicalTypeId::BIGINT, pgwire::Oid::Int8},
    {LogicalTypeId::HUGEINT, pgwire::Oid::Int8},
    // The types below are sent to the client in their string form.
    {LogicalTypeId::LIST, pgwire::Oid::Varchar},
    {LogicalTypeId::VARCHAR, pgwire::Oid::Varchar},
    {LogicalTypeId::DATE, pgwire::Oid::Date},
    {LogicalTypeId::TIME, pgwire::Oid::Time},
    {LogicalTypeId::TIMESTAMP, pgwire::Oid::Timestamp},
    // This key used to be a second TIMESTAMP entry; a duplicate key in the
    // initializer list is ignored, which left TIMESTAMP_TZ unmapped.
    {LogicalTypeId::TIMESTAMP_TZ, pgwire::Oid::TimestampTz},
};

/// Builds the pgwire parse handler that runs incoming SQL text against `db`.
///
/// For each query text the returned callable opens a new Connection, splits
/// the text into statements with ExtractStatements, executes every statement
/// except the last one right away, and prepares the last one. This is what
/// lets a text that expands into several statements (such as PIVOT) work even
/// though only a single statement can be prepared. The prepared statement's
/// column names and types are turned into the field list of the returned
/// pgwire::PreparedStatement, and its `handler` executes the prepared
/// statement and writes the result rows.
///
/// Columns whose type has no entry in g_typemap are left out of both the
/// field list and the row data.
///
/// @param db  Database to run queries on. Captured by reference, so it must
///            outlive the returned handler.
/// @return    The parse handler.
/// @throws pgwire::SqlException (SqlState::DataException) from the handler
///         when the text contains no statement, when a leading statement or
///         the prepare step fails, and from `stmt.handler` when execution
///         fails.
static pgwire::ParseHandler duckdb_handler(DatabaseInstance &db) {
    return [&db](std::string const &query) mutable {
        Connection conn(db);
        pgwire::PreparedStatement stmt;
        std::unique_ptr<PreparedStatement> prepared;
        std::optional<pgwire::SqlException> error;

        std::vector<std::string> column_names;
        std::vector<LogicalType> column_types;
        std::size_t column_total = 0;

        try {
            // Statement extraction happens inside the try block so that an
            // exception raised while splitting the text is reported to the
            // client like every other failure.
            auto statements = conn.ExtractStatements(query);

            // Guard against an empty list: `size() - 1` would wrap around.
            if (statements.empty()) {
                throw std::runtime_error("no statement to prepare");
            }
            const std::size_t last = statements.size() - 1;

            // Run all leading statements eagerly; only the final statement
            // produces the result set that is described to the client.
            for (std::size_t i = 0; i < last; ++i) {
                auto pending =
                    conn.PendingQuery(std::move(statements[i]), true);
                if (!pending) {
                    throw std::runtime_error(
                        "failed to start query with unknown error");
                }
                if (pending->HasError()) {
                    throw std::runtime_error(pending->GetError());
                }

                auto leading_result = pending->Execute();
                if (!leading_result) {
                    throw std::runtime_error(
                        "failed to execute query with unknown error");
                }
                if (leading_result->HasError()) {
                    throw std::runtime_error(leading_result->GetError());
                }
            }

            prepared = conn.Prepare(std::move(statements[last]));
            if (!prepared) {
                throw std::runtime_error(
                    "failed prepare query with unknown error");
            }

            if (prepared->HasError()) {
                throw std::runtime_error(prepared->GetError());
            }

            column_names = prepared->GetNames();
            column_types = prepared->GetTypes();
            column_total = prepared->ColumnCount();

        } catch (std::exception const &e) {
            error =
                pgwire::SqlException{e.what(), pgwire::SqlState::DataException};
        }

        // rethrow outside of the catch block
        if (error) {
            throw *error;
        }

        stmt.fields.reserve(column_total);
        for (std::size_t i = 0; i < column_total; i++) {
            auto &name = column_names[i];
            auto &type = column_types[i];

            // Unsupported column types are not advertised to the client.
            auto it = g_typemap.find(type.id());
            if (it == g_typemap.end()) {
                continue;
            }
            auto oid = it->second;

            // can't uses emplace_back for POD struct in C++17
            stmt.fields.push_back({name, oid});
        }

        stmt.handler = [column_total, p = std::move(prepared)](
                           pgwire::Writer &writer,
                           pgwire::Values const &parameters) mutable {
            std::unique_ptr<QueryResult> result;
            std::optional<pgwire::SqlException> error;

            try {
                result = p->Execute();

                if (!result) {
                    throw std::runtime_error(
                        "failed to execute query with unknown error");
                }

                if (result->HasError()) {
                    throw std::runtime_error(result->GetError());
                }

            } catch (std::exception const &e) {
                error = pgwire::SqlException{e.what(),
                                             pgwire::SqlState::DataException};
            }

            if (error) {
                throw *error;
            }

            auto &column_types = p->GetTypes();

            for (auto &chunk : *result) {
                auto row = writer.add_row();

                for (std::size_t i = 0; i < column_total; i++) {
                    auto &type = column_types[i];

                    // Must mirror the field list: skip unsupported types.
                    auto it = g_typemap.find(type.id());
                    if (it == g_typemap.end()) {
                        continue;
                    }

                    auto value = chunk.iterator.chunk->GetValue(i, chunk.row);
                    if (value.IsNull()) {
                        row.write_null();
                        continue;
                    }

                    switch (type.id()) {
                    case LogicalTypeId::FLOAT:
                        row.write_float4(chunk.GetValue<float>(i));
                        break;
                    case LogicalTypeId::DOUBLE:
                        row.write_float8(chunk.GetValue<double>(i));
                        break;
                    case LogicalTypeId::SMALLINT:
                        row.write_int2(chunk.GetValue<int16_t>(i));
                        break;
                    case LogicalTypeId::INTEGER:
                        row.write_int4(chunk.GetValue<int32_t>(i));
                        break;
                    case LogicalTypeId::BIGINT:
                        row.write_int8(chunk.GetValue<int64_t>(i));
                        break;
                    case LogicalTypeId::HUGEINT: {
                        // Read the value once and combine the two halves in
                        // unsigned arithmetic, so a negative upper half is
                        // never left-shifted as a signed value.
                        const auto huge = chunk.GetValue<hugeint_t>(i);
                        const auto bits =
                            (static_cast<__uint128_t>(huge.upper) << 64) |
                            static_cast<__uint128_t>(huge.lower);
                        row.write_int16(static_cast<__int128_t>(bits));
                        break;
                    }
                    case LogicalTypeId::BOOLEAN:
                        row.write_bool(chunk.GetValue<bool>(i));
                        break;
                    case LogicalTypeId::LIST:
                    case LogicalTypeId::VARCHAR:
                    case LogicalTypeId::DATE:
                    case LogicalTypeId::TIME:
                    case LogicalTypeId::TIMESTAMP:
                    case LogicalTypeId::TIMESTAMP_TZ:
                        // These types are sent in their string form.
                        row.write_string(chunk.GetValue<std::string>(i));
                        break;
                    default:
                        break;
                    }
                }
            }
        };
        return stmt;
    };
}

/// Starts the pgwire server for `db` unless one has already been started.
///
/// Sets up an asio io_context, initializes pgwire logging to
/// "duckdb_pgwire.log", and starts a pgwire::Server on IPv4 port 15432 whose
/// sessions are served by duckdb_handler(db).
///
/// NOTE(review): presumably server.start() runs the io_context and blocks for
/// the lifetime of the server; otherwise the locals below would be destroyed
/// on return. Confirm against pgwire.
///
/// @param db  Database the sessions run their queries on.
static void start_server(DatabaseInstance &db) {
    // exchange() tests and claims the flag in one atomic step, so two threads
    // that get here at the same time cannot both start a server.
    if (g_started.exchange(true)) {
        return;
    }

    asio::io_context io;
    asio::ip::tcp::endpoint endpoint(asio::ip::tcp::v4(), 15432);

    pgwire::log::initialize(io, "duckdb_pgwire.log");

    pgwire::Server server(
        io, endpoint,
        [&db](pgwire::Session &sess) mutable { return duckdb_handler(db); });
    server.start();
}

/// Scalar function body for `pg_is_in_recovery()`: always yields false.
///
/// The result is produced as a constant vector so that every row of the
/// output is defined, not just row 0.
///
/// @param args    Input chunk (the function takes no arguments; unused).
/// @param state   Expression state (unused).
/// @param result  Output vector; set to the constant BOOLEAN false.
inline void PgIsInRecovery(DataChunk &args, ExpressionState &state,
                           Vector &result) {
    result.Reference(Value::BOOLEAN(false));
}

/// Scalar function body for `duckdb_pgwire(name)`: maps every input string to
/// "DuckdbPgwire <name> 🐥".
///
/// @param args    Input chunk; column 0 holds the name strings.
/// @param state   Expression state (unused).
/// @param result  Output vector that receives the decorated strings.
inline void DuckdbPgwireScalarFun(DataChunk &args, ExpressionState &state, Vector &result) {
    auto &input = args.data[0];
    const auto row_count = args.size();

    // The produced string is stored in `result` through StringVector::AddString.
    auto decorate = [&](string_t name) {
        return StringVector::AddString(result, "DuckdbPgwire " + name.GetString() + " 🐥");
    };

    UnaryExecutor::Execute<string_t, string_t>(input, result, row_count, decorate);
}
/*
static void LoadInternal(DatabaseInstance &instance) {
    // Register a scalar function
    auto pg_is_in_recovery_scalar_function = ScalarFunction(
        "pg_is_in_recovery", {}, LogicalType::BOOLEAN, PgIsInRecovery);
    ExtensionUtil::RegisterFunction(instance,
                                    pg_is_in_recovery_scalar_function);

    auto duckdb_pgwire_scalar_function = ScalarFunction("duckdb_pgwire", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DuckdbPgwireScalarFun);
    ExtensionUtil::RegisterFunction(instance, duckdb_pgwire_scalar_function);

    std::thread([&instance]() mutable { start_server(instance); }).detach();
}

void DuckdbPgwireExtension::Load(DuckDB &db) { LoadInternal(*db.instance); }
std::string DuckdbPgwireExtension::Name() { return "duckdb_pgwire"; }
*/

/// Registers the extension's scalar functions and launches the pgwire server
/// on a detached background thread.
///
/// @param loader  Extension loader used to register the functions and to
///                obtain the DatabaseInstance that the server runs against.
///                The thread keeps a reference to that instance.
static void LoadInternal(ExtensionLoader &loader) {
	// Register a scalar function
	auto pg_is_in_recovery_scalar_function = ScalarFunction("pg_is_in_recovery", {}, LogicalType::BOOLEAN, PgIsInRecovery);
	loader.RegisterFunction(pg_is_in_recovery_scalar_function);

	// Register another scalar function
	auto duckdb_pgwire_scalar_function = ScalarFunction("duckdb_pgwire", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DuckdbPgwireScalarFun);
	loader.RegisterFunction(duckdb_pgwire_scalar_function);

	DatabaseInstance &instance = loader.GetDatabaseInstance();
	std::thread([&instance]() {
		// An exception that leaves a thread's entry function calls
		// std::terminate, which would take down the whole host process.
		// Report the failure on stderr and let the thread end instead.
		try {
			start_server(instance);
		} catch (const std::exception &e) {
			std::fprintf(stderr, "duckdb_pgwire: server stopped: %s\n", e.what());
		} catch (...) {
			std::fprintf(stderr, "duckdb_pgwire: server stopped: unknown error\n");
		}
	}).detach();
}

/// Extension load hook: forwards to LoadInternal with the given loader.
void DuckdbPgwireExtension::Load(ExtensionLoader &loader) {
	duckdb::LoadInternal(loader);
}
/// Returns the extension's name, "duckdb_pgwire".
std::string DuckdbPgwireExtension::Name() {
	static constexpr const char *kExtensionName = "duckdb_pgwire";
	return kExtensionName;
}
/// Returns the extension version string supplied at build time, or an empty
/// string when no version macro is defined.
///
/// EXT_VERSION_QUACK is the macro name of the extension template this file
/// was derived from. The macro matching this extension's own name is checked
/// first, and the template name is kept as a fallback for existing builds.
/// NOTE(review): confirm that the build defines EXT_VERSION_DUCKDB_PGWIRE.
std::string DuckdbPgwireExtension::Version() const {
#if defined(EXT_VERSION_DUCKDB_PGWIRE)
	return EXT_VERSION_DUCKDB_PGWIRE;
#elif defined(EXT_VERSION_QUACK)
	return EXT_VERSION_QUACK;
#else
	return "";
#endif
}

} // namespace duckdb

extern "C" {
/*
DUCKDB_EXTENSION_API void duckdb_pgwire_init(duckdb::DatabaseInstance &db) {
    duckdb::DuckDB db_wrapper(db);
    db_wrapper.LoadExtension<duckdb::DuckdbPgwireExtension>();
}

DUCKDB_EXTENSION_API const char *duckdb_pgwire_version() {
    return duckdb::DuckDB::LibraryVersion();
}
*/
// The commented-out block above is the previous pair of entry points that
// took a DatabaseInstance; it is kept for reference only.
//
// Entry point for the extension named duckdb_pgwire, declared through
// DuckDB's DUCKDB_CPP_EXTENSION_ENTRY macro. It receives the extension
// loader and hands it to duckdb::LoadInternal.
DUCKDB_CPP_EXTENSION_ENTRY(duckdb_pgwire, loader) {
	duckdb::LoadInternal(loader);
}
}

// Compile-time check: fail the build if DUCKDB_EXTENSION_MAIN, which is
// defined near the top of this file, is not defined at this point.
#ifndef DUCKDB_EXTENSION_MAIN
#error DUCKDB_EXTENSION_MAIN not defined
#endif

执行结果如下

D CREATE TABLE cities (    country VARCHAR, name VARCHAR, year INTEGER, population INTEGER);
D INSERT INTO cities VALUES
      ('NL', 'Amsterdam', 2000, 1005),
      ('NL', 'Amsterdam', 2010, 1065),
      ('NL', 'Amsterdam', 2020, 1158),
      ('US', 'Seattle', 2000, 564),
      ('US', 'Seattle', 2010, 608),
      ('US', 'Seattle', 2020, 738),
      ('US', 'New York City', 2000, 8015),
      ('US', 'New York City', 2010, 8175),
      ('US', 'New York City', 2020, 8772);
main=> select 1::int128 a;
 a
---
 1
(1 row)

main=> PIVOT cities ON year USING first(population);
 country |     name      | 2000 | 2010 | 2020
---------+---------------+------+------+------
 NL      | Amsterdam     | 1005 | 1065 | 1158
 US      | New York City | 8015 | 8175 | 8772
 US      | Seattle       |  564 |  608 |  738
(3 rows)


main=> select list(i)from range(4)t(i);
   list(i)
--------------
 [0, 1, 2, 3]
(1 row)

main=>
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值