Create new tool for CSV

本文详细介绍了CsvFileStream、CsvFileWriter和CsvFileReader类的功能,包括如何使用这些类进行CSV文件的读取、写入和解析。重点阐述了如何通过CsvFileStream类逐行获取CSV文件的每一行数据、通过CsvFileReader类将其解析为DataTable,并通过CsvFileWriter类实现CSV文件的高效写入与格式化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

CsvFileStream.cs

/// <summary>
/// Pull-style CSV tokenizer over a <see cref="TextReader"/>. Each call to
/// <see cref="GetNextRow"/> returns the fields of the next line.
/// A field may be wrapped in double quotes; inside quotes a doubled quote
/// stands for one literal quote, and commas and line breaks are plain data.
/// Spaces before the start of a field are discarded.
/// The supplied reader is not closed or disposed by this class.
/// </summary>
public class CsvFileStream
    {
        TextReader stream;
        bool EOS = false;   // set once the underlying reader has no more characters
        bool EOL = false;   // set when the field just returned was the last one on its line

        // True when the previous field was terminated by a comma, which means
        // another field follows even if it turns out to be empty.
        bool afterComma = false;

        /// <summary>Creates a tokenizer that reads from <paramref name="s"/>.</summary>
        /// <param name="s">Source of CSV text.</param>
        /// <exception cref="ArgumentNullException"><paramref name="s"/> is null.</exception>
        public CsvFileStream(TextReader s)
        {
            if (s == null)
            {
                throw new ArgumentNullException("s");
            }
            stream = s;
        }

        /// <summary>Reads the next line of the CSV data.</summary>
        /// <returns>
        /// The fields of the line, or <c>null</c> when the input is exhausted.
        /// A blank line yields a single empty field.
        /// </returns>
        public string[] GetNextRow()
        {
            ArrayList row = new ArrayList();
            while (true)
            {
                string item = GetNextItem();
                if (item == null)
                {
                    return row.Count == 0 ? null : (string[])row.ToArray(typeof(string));
                }
                row.Add(item);
            }
        }

        /// <summary>
        /// Reads one field. Returns <c>null</c> to signal "end of the current row"
        /// (either the previous field ended the line, or the input is exhausted).
        /// </summary>
        string GetNextItem()
        {
            if (EOL)
            {
                // previous item was last in line, start new line
                EOL = false;
                return null;
            }

            // A comma right before this field guarantees the field exists, so an
            // empty value must still be reported when the stream ends here.
            bool pendingField = afterComma;
            afterComma = false;

            bool quoted = false;     // field was opened with a double quote
            bool predata = true;     // no data character seen yet
            bool postdata = false;   // closing quote seen, waiting for delimiter
            StringBuilder sb = new StringBuilder();

            while (true)
            {
                char c = GetNextChar(true);
                if (EOS)
                {
                    // The field exists if it has content, was opened by a quote
                    // (e.g. a trailing ""), or follows a comma (e.g. "a,").
                    if (sb.Length > 0 || quoted || pendingField)
                    {
                        return sb.ToString();
                    }
                    return null;
                }

                if ((postdata || !quoted) && c == ',')
                {
                    // end of item, return
                    afterComma = true;
                    return sb.ToString();
                }

                if ((predata || postdata || !quoted) && (c == '\x0A' || c == '\x0D'))
                {
                    // we are at the end of the line, eat newline characters and exit
                    EOL = true;
                    if (c == '\x0D' && GetNextChar(false) == '\x0A')
                    {
                        // new line sequence is 0D0A
                        GetNextChar(true);
                    }
                    return sb.ToString();
                }

                if (predata && c == ' ')
                {
                    // whitespace preceding data, discard
                    continue;
                }

                if (predata && c == '"')
                {
                    // quoted data is starting
                    quoted = true;
                    predata = false;
                    continue;
                }

                if (predata)
                {
                    // data is starting without quotes
                    predata = false;
                    sb.Append(c);
                    continue;
                }

                if (c == '"' && quoted)
                {
                    if (GetNextChar(false) == '"')
                    {
                        // double quotes within quoted string means add a quote
                        sb.Append(GetNextChar(true));
                    }
                    else
                    {
                        // end-quote reached
                        postdata = true;
                    }
                    continue;
                }

                // all cases covered, character must be data
                sb.Append(c);
            }
        }

        char[] buffer = new char[4096];
        int pos = 0;
        int length = 0;

        /// <summary>
        /// Returns the next character, refilling the buffer from the reader when needed.
        /// </summary>
        /// <param name="eat">
        /// <c>true</c> to consume the character, <c>false</c> to peek at it.
        /// </param>
        /// <returns>
        /// The character, or '\0' with <c>EOS</c> set when the reader is exhausted.
        /// </returns>
        char GetNextChar(bool eat)
        {
            if (pos >= length)
            {
                length = stream.ReadBlock(buffer, 0, buffer.Length);
                if (length == 0)
                {
                    EOS = true;
                    return '\0';
                }
                pos = 0;
            }
            if (eat)
            {
                return buffer[pos++];
            }
            return buffer[pos];
        }
    }

CsvFileWriter.cs

/// <summary>
/// In-memory grid of string cells that can be written out as a CSV file.
/// Cells are addressed with 1-based row and column numbers; rows and columns
/// are created on demand and padded with empty strings.
/// </summary>
public class CsvFileWriter
    {
        private ArrayList rowAL;        // list of rows; each row is an ArrayList of cell values
        private string fileName;        // target file name, including path
        private Encoding encoding;      // encoding used when saving

        /// <summary>
        /// Creates a writer with no file name and the default encoding.
        /// A file name must be supplied before or when saving.
        /// </summary>
        public CsvFileWriter()
            : this("", Encoding.Default)
        {
        }

        /// <summary>
        /// Creates a writer for the given file using the default encoding.
        /// </summary>
        /// <param name="fileName">File name, including path.</param>
        public CsvFileWriter(string fileName)
            : this(fileName, Encoding.Default)
        {
        }

        /// <summary>
        /// Creates a writer for the given file and encoding.
        /// </summary>
        /// <param name="fileName">File name, including path.</param>
        /// <param name="encoding">File encoding.</param>
        public CsvFileWriter(string fileName, Encoding encoding)
        {
            this.rowAL = new ArrayList();
            this.fileName = fileName;
            this.encoding = encoding;
        }

        /// <summary>
        /// Sets the value of a cell, growing the grid as needed.
        /// row: 1 is the first row. col: 1 is the first column.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="row"/> or <paramref name="col"/> is less than 1.
        /// </exception>
        public string this[int row, int col]
        {
            set
            {
                // Validate both coordinates before touching the grid so that a
                // bad column does not leave freshly added empty rows behind.
                if (row <= 0)
                {
                    throw new ArgumentOutOfRangeException("row", "行数不能小于0");
                }
                if (col <= 0)
                {
                    throw new ArgumentOutOfRangeException("col", "列数不能小于0");
                }

                // Add missing rows up to and including the requested one.
                while (this.rowAL.Count < row)
                {
                    this.rowAL.Add(new ArrayList());
                }

                // Pad the row to exactly 'col' cells. The original loop ran one
                // step too far and left a spurious empty trailing cell.
                ArrayList colAL = (ArrayList)this.rowAL[row - 1];
                while (colAL.Count < col)
                {
                    colAL.Add("");
                }

                colAL[col - 1] = value;
            }
        }

        /// <summary>
        /// File name, including path.
        /// </summary>
        public string FileName
        {
            set
            {
                this.fileName = value;
            }
        }

        /// <summary>
        /// File encoding.
        /// </summary>
        public Encoding FileEncoding
        {
            set
            {
                this.encoding = value;
            }
        }

        /// <summary>
        /// Number of rows currently in the grid.
        /// </summary>
        public int CurMaxRow
        {
            get
            {
                return this.rowAL.Count;
            }
        }

        /// <summary>
        /// Number of cells in the widest row (0 when the grid is empty).
        /// </summary>
        public int CurMaxCol
        {
            get
            {
                int maxCol = 0;
                foreach (ArrayList colAL in this.rowAL)
                {
                    if (colAL.Count > maxCol)
                    {
                        maxCol = colAL.Count;
                    }
                }
                return maxCol;
            }
        }

        /// <summary>
        /// Appends the contents of a table below the rows already present.
        /// Column headers are not written, only the row values.
        /// </summary>
        /// <param name="dataDT">Table to append.</param>
        /// <param name="beginCol">Column to start at; 1 is the first column.</param>
        /// <exception cref="ArgumentNullException"><paramref name="dataDT"/> is null.</exception>
        public void AddData(DataTable dataDT, int beginCol)
        {
            if (dataDT == null)
            {
                throw new ArgumentNullException("dataDT", "需要添加的表数据为空");
            }

            // Captured once: the grid grows while cells are being assigned.
            int firstNewRow = this.rowAL.Count + 1;
            for (int i = 0; i < dataDT.Rows.Count; i++)
            {
                DataRow source = dataDT.Rows[i];
                for (int j = 0; j < dataDT.Columns.Count; j++)
                {
                    this[firstNewRow + i, beginCol + j] = source[j].ToString();
                }
            }
        }

        /// <summary>
        /// Writes the grid to the configured file. An existing file with the
        /// same name is overwritten.
        /// </summary>
        /// <exception cref="InvalidOperationException">No file name has been set.</exception>
        public void Save()
        {
            // An empty name (the default constructor's value) is as unusable as null.
            if (string.IsNullOrEmpty(this.fileName))
            {
                throw new InvalidOperationException("缺少文件名");
            }
            if (File.Exists(this.fileName))
            {
                File.Delete(this.fileName);
            }
            if (this.encoding == null)
            {
                this.encoding = Encoding.Default;
            }

            // 'using' guarantees the file handle is released even if a write fails.
            using (StreamWriter sw = new StreamWriter(this.fileName, false, this.encoding))
            {
                foreach (ArrayList colAL in this.rowAL)
                {
                    sw.WriteLine(ConvertToSaveLine(colAL));
                }
            }
        }

        /// <summary>
        /// Writes the grid to the given file. An existing file with the same
        /// name is overwritten.
        /// </summary>
        /// <param name="fileName">File name, including path.</param>
        public void Save(string fileName)
        {
            this.fileName = fileName;
            Save();
        }

        /// <summary>
        /// Writes the grid to the given file with the given encoding. An
        /// existing file with the same name is overwritten.
        /// </summary>
        /// <param name="fileName">File name, including path.</param>
        /// <param name="encoding">File encoding.</param>
        public void Save(string fileName, Encoding encoding)
        {
            this.fileName = fileName;
            this.encoding = encoding;
            Save();
        }

        /// <summary>
        /// Converts one row into a line of CSV text.
        /// </summary>
        /// <param name="colAL">The cells of one row.</param>
        /// <returns>The quoted cells joined by commas, without a line terminator.</returns>
        private string ConvertToSaveLine(ArrayList colAL)
        {
            StringBuilder saveLine = new StringBuilder();
            for (int i = 0; i < colAL.Count; i++)
            {
                // Cells are separated by commas.
                if (i > 0)
                {
                    saveLine.Append(',');
                }
                // A cell explicitly set to null is written as an empty cell.
                object cell = colAL[i];
                saveLine.Append(ConvertToSaveCell(cell == null ? "" : cell.ToString()));
            }
            return saveLine.ToString();
        }

        /// <summary>
        /// Converts a string into a CSV cell: every double quote is doubled and
        /// the result is wrapped in double quotes, so embedded commas and line
        /// breaks need no further handling.
        /// </summary>
        /// <param name="cell">Cell content.</param>
        /// <returns>The quoted cell text.</returns>
        private string ConvertToSaveCell(string cell)
        {
            return "\"" + cell.Replace("\"", "\"\"") + "\"";
        }
    }

CsvFileReader.cs

/// <summary>
/// Static helpers that load CSV text into a <see cref="DataTable"/> whose
/// columns are all of type string.
/// </summary>
public class CsvFileReader
    {
        /// <summary>Parses CSV text held in a string.</summary>
        /// <param name="data">The CSV text.</param>
        /// <param name="headers">True when the first row holds column names.</param>
        /// <returns>The parsed table, or null when there are no rows.</returns>
        public static DataTable Parse(string data, bool headers)
        {
            StringReader reader = new StringReader(data);
            return Parse(reader, headers);
        }

        /// <summary>Parses CSV text held in a string, treating every row as data.</summary>
        /// <param name="data">The CSV text.</param>
        /// <returns>The parsed table, or null when there are no rows.</returns>
        public static DataTable Parse(string data)
        {
            StringReader reader = new StringReader(data);
            return Parse(reader);
        }

        /// <summary>Parses CSV text from a reader, treating every row as data.</summary>
        /// <param name="stream">Source of CSV text.</param>
        /// <returns>The parsed table, or null when there are no rows.</returns>
        public static DataTable Parse(TextReader stream)
        {
            return Parse(stream, false);
        }

        /// <summary>Parses CSV text from a reader.</summary>
        /// <param name="stream">Source of CSV text.</param>
        /// <param name="headers">
        /// True when the first row holds column names. Empty or duplicate names
        /// are replaced with generated "ColumnN" names.
        /// </param>
        /// <returns>The parsed table, or null when there are no rows.</returns>
        public static DataTable Parse(TextReader stream, bool headers)
        {
            DataTable result = new DataTable();
            CsvFileStream source = new CsvFileStream(stream);

            string[] current = source.GetNextRow();
            if (current == null)
            {
                return null;
            }

            if (headers)
            {
                for (int i = 0; i < current.Length; i++)
                {
                    string name = current[i];
                    bool usable = !string.IsNullOrEmpty(name) && !result.Columns.Contains(name);
                    if (!usable)
                    {
                        name = GetNextColumnHeader(result);
                    }
                    result.Columns.Add(name, typeof(string));
                }
                current = source.GetNextRow();
            }

            for (; current != null; current = source.GetNextRow())
            {
                // Widen the table when a row has more fields than known columns.
                for (int missing = current.Length - result.Columns.Count; missing > 0; missing--)
                {
                    result.Columns.Add(GetNextColumnHeader(result), typeof(string));
                }
                result.Rows.Add(current);
            }

            return result;
        }

        /// <summary>
        /// Returns the first name of the form "Column1", "Column2", ... that is
        /// not yet used by a column of <paramref name="table"/>.
        /// </summary>
        static string GetNextColumnHeader(DataTable table)
        {
            for (int number = 1; ; number++)
            {
                string candidate = "Column" + number;
                if (!table.Columns.Contains(candidate))
                {
                    return candidate;
                }
            }
        }
    }

使用示例(e.g.):

// Load D:\test.csv into a DataTable, treating every row as data (no header row).
// The 'using' block closes the file as soon as parsing has finished.
DataTable dt;
using (TextReader reader = new StreamReader(@"D:\test.csv"))
{
    dt = CsvFileReader.Parse(reader, false);
}

 

转载于:https://www.cnblogs.com/vincentDr/p/3600664.html

"""
title: SQL Server Access
author: MENG
author_urls:
  - https://github.com/mengvision
description: A tool for reading database information and executing SQL queries, supporting multiple databases such as MySQL, PostgreSQL, SQLite, and Oracle. It provides functionalities for listing all tables, describing table schemas, and returning query results in CSV format. A versatile DB Agent for seamless database interactions.
required_open_webui_version: 0.5.4
requirements: pymysql, sqlalchemy, cx_Oracle
version: 0.1.6
licence: MIT

# Changelog

## [0.1.6] - 2025-03-11
### Added
- Added `get_table_indexes` method to retrieve index information for a specific table, supporting MySQL, PostgreSQL, SQLite, and Oracle.
- Enhanced metadata capabilities by providing detailed index descriptions (e.g., index name, columns, and type).
- Improved documentation to include the new `get_table_indexes` method and its usage examples.
- Updated error handling in `get_table_indexes` to provide more detailed feedback for unsupported database types.

## [0.1.5] - 2025-01-20
### Changed
- Updated `list_all_tables` and `table_data_schema` methods to accept `db_name` as a function parameter instead of using `self.valves.db_name`.
- Improved flexibility by decoupling database name from class variables, allowing dynamic database selection at runtime.

## [0.1.4] - 2025-01-17
### Added
- Added support for Oracle database using `cx_Oracle` driver.
- Added dynamic engine creation in each method to ensure fresh database connections for every operation.
- Added support for Oracle-specific queries in `list_all_tables` and `table_data_schema` methods.
### Changed
- Moved `self._get_engine()` from `__init__` to individual methods for better flexibility and tool compatibility.
- Updated `_get_engine` method to support Oracle database connection URL.
- Improved `table_data_schema` method to handle Oracle-specific column metadata.
### Fixed
- Fixed potential connection issues by ensuring each method creates its own database engine.
- Improved error handling for Oracle-specific queries and edge cases.

## [0.1.3] - 2025-01-17
### Added
- Added support for multiple database types (e.g., MySQL, PostgreSQL, SQLite) using SQLAlchemy.
- Added configuration flexibility through environment variables or external configuration files.
- Enhanced query security with stricter validation and SQL injection prevention.
- Improved error handling with detailed exception messages for better debugging.
### Changed
- Replaced `pymysql` with SQLAlchemy for broader database compatibility.
- Abstracted database connection logic into a reusable `_get_engine` method.
- Updated `table_data_schema` method to support multiple database types.
### Fixed
- Fixed potential SQL injection vulnerabilities in query execution.
- Improved handling of edge cases in query validation and execution.

## [0.1.2] - 2025-01-16
### Added
- Added support for specifying the database port with a default value of `3306`.
- Abstracted database connection logic into a reusable `_get_connection` method.

## [0.1.1] - 2025-01-16
### Added
- Support for additional read-only query types: `SHOW`, `DESCRIBE`, `EXPLAIN`, and `USE`.
- Enhanced query validation to block sensitive keywords (e.g., `INSERT`, `UPDATE`, `DELETE`, `CREATE`, `DROP`, `ALTER`).
### Fixed
- Improved handling of queries starting with `WITH` (CTE queries).
- Fixed case sensitivity issues in query validation.

## [0.1.0] - 2025-01-09
### Initial Release
- Basic functionality for listing tables, describing table schemas, and executing `SELECT` queries.
- Query results returned in CSV format.
"""

import csv
import io
import os
import re
from typing import List, Dict, Any
from urllib.parse import quote

from pydantic import BaseModel, Field
from sqlalchemy import create_engine, text
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import SQLAlchemyError

# A bare or schema-qualified identifier made of word characters and "$".
# Used for the statements (SHOW INDEX, PRAGMA) in which the table name cannot
# be sent as a bind parameter and therefore has to be validated before being
# placed into the SQL text.
_IDENTIFIER_RE = re.compile(r"[\w$]+(\.[\w$]+)?")


def _is_safe_identifier(name: str) -> bool:
    """Return True when ``name`` can be embedded in SQL text without quoting risks."""
    return isinstance(name, str) and _IDENTIFIER_RE.fullmatch(name) is not None


class Tools:
    class Valves(BaseModel):
        db_host: str = Field(
            default="localhost",
            description="The host of the database. Replace with your own host.",
        )
        db_user: str = Field(
            default="admin",
            description="The username for the database. Replace with your own username.",
        )
        db_password: str = Field(
            default="admin",
            description="The password for the database. Replace with your own password.",
        )
        db_name: str = Field(
            default="db",
            description="The name of the database. Replace with your own database name.",
        )
        db_port: int = Field(
            default=3306,  # MySQL's default port; set it explicitly for other database types
            description="The port of the database. Replace with your own port.",
        )
        db_type: str = Field(
            default="mysql",
            description="The type of the database (e.g., mysql, postgresql, sqlite, oracle).",
        )

    def __init__(self):
        """
        Initialize the Tools class with the credentials for the database.
        """
        print("Initializing database tool class")
        self.citation = True
        self.valves = Tools.Valves()

    def _get_engine(self) -> Engine:
        """
        Create and return a database engine using the current configuration.

        :return: A new SQLAlchemy engine; the caller is responsible for disposing it.
        :raises ValueError: If ``valves.db_type`` is not one of the supported types.
        """
        valves = self.valves
        # Percent-encode the credentials so characters such as "@", ":" or "/"
        # in a user name or password cannot corrupt the connection URL.
        user = quote(valves.db_user, safe="")
        password = quote(valves.db_password, safe="")
        location = f"{valves.db_host}:{valves.db_port}"

        if valves.db_type == "mysql":
            db_url = f"mysql+pymysql://{user}:{password}@{location}/{valves.db_name}"
        elif valves.db_type == "postgresql":
            db_url = f"postgresql://{user}:{password}@{location}/{valves.db_name}"
        elif valves.db_type == "sqlite":
            db_url = f"sqlite:///{valves.db_name}"
        elif valves.db_type == "oracle":
            db_url = f"oracle+cx_oracle://{user}:{password}@{location}/?service_name={valves.db_name}"
        else:
            raise ValueError(f"Unsupported database type: {valves.db_type}")
        return create_engine(db_url)

    def list_all_tables(self, db_name: str) -> str:
        """
        List all tables in the database.

        :param db_name: The name of the database. Currently not read by this method;
            the database configured in the valves is the one that is queried.
        :return: A string containing the names of all tables.
        """
        print("Listing all tables in the database")
        engine = self._get_engine()  # a fresh engine is created for every call
        try:
            with engine.connect() as conn:
                # No trailing semicolons: the Oracle driver rejects them and the
                # other databases do not need them.
                if self.valves.db_type == "mysql":
                    result = conn.execute(text("SHOW TABLES"))
                elif self.valves.db_type == "postgresql":
                    result = conn.execute(
                        text(
                            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
                        )
                    )
                elif self.valves.db_type == "sqlite":
                    result = conn.execute(
                        text("SELECT name FROM sqlite_master WHERE type='table'")
                    )
                elif self.valves.db_type == "oracle":
                    result = conn.execute(text("SELECT table_name FROM user_tables"))
                else:
                    return "Unsupported database type."

                tables = [row[0] for row in result.fetchall()]
                if tables:
                    return (
                        "Here is a list of all the tables in the database:\n\n"
                        + "\n".join(tables)
                    )
                else:
                    return "No tables found."
        except SQLAlchemyError as e:
            return f"Error listing tables: {str(e)}"
        finally:
            # The engine is per-call, so release its connection pool here.
            engine.dispose()

    def get_table_indexes(self, db_name: str, table_name: str) -> str:
        """
        Get the indexes of a specific table in the database.

        :param db_name: The name of the database. Currently not read by this method.
        :param table_name: The name of the table.
        :return: A string describing the indexes of the table.
        """
        print(f"Getting indexes for table: {table_name}")
        db_type = self.valves.db_type
        engine = self._get_engine()
        try:
            # Positions of the two result columns shown for every index row.
            key, column = 0, 1
            params = {"table_name": table_name}
            with engine.connect() as conn:
                if db_type == "mysql":
                    # SHOW INDEX cannot take the table as a bind parameter.
                    if not _is_safe_identifier(table_name):
                        return f"Error getting indexes: invalid table name '{table_name}'"
                    query = text(f"SHOW INDEX FROM {table_name}")
                    params = {}
                    key, column = 2, 4  # Key_name, Column_name
                elif db_type == "postgresql":
                    query = text(
                        "SELECT indexname, indexdef FROM pg_indexes "
                        "WHERE tablename = :table_name"
                    )
                elif db_type == "sqlite":
                    # PRAGMA arguments cannot be bound either.
                    if not _is_safe_identifier(table_name):
                        return f"Error getting indexes: invalid table name '{table_name}'"
                    query = text(f'PRAGMA index_list("{table_name}")')
                    params = {}
                elif db_type == "oracle":
                    query = text(
                        "SELECT index_name, column_name FROM user_ind_columns "
                        "WHERE table_name = :table_name"
                    )
                else:
                    return "Unsupported database type."

                indexes = conn.execute(query, params).fetchall()
                if not indexes:
                    return f"No indexes found for table: {table_name}"

                description = f"Indexes for table '{table_name}':\n"
                for index in indexes:
                    description += f"- {index[key]}: {index[column]}\n"
                return description
        except SQLAlchemyError as e:
            return f"Error getting indexes: {str(e)}"
        finally:
            engine.dispose()

    def table_data_schema(self, db_name: str, table_name: str) -> str:
        """
        Describe the schema of a specific table in the database, including column comments.

        :param db_name: The name of the database. Currently not read by this method;
            the MySQL branch filters on the database configured in the valves.
        :param table_name: The name of the table to describe.
        :return: A string describing the data schema of the table.
        """
        print(f"Database: {self.valves.db_name}")
        print(f"Describing table: {table_name}")
        db_type = self.valves.db_type
        engine = self._get_engine()  # a fresh engine is created for every call
        try:
            params = {"table_name": table_name}
            with engine.connect() as conn:
                if db_type == "mysql":
                    query = text(
                        "SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE, COLUMN_KEY, COLUMN_COMMENT "
                        "FROM INFORMATION_SCHEMA.COLUMNS "
                        "WHERE TABLE_SCHEMA = :db_name AND TABLE_NAME = :table_name"
                    )
                    params = {
                        "db_name": self.valves.db_name,
                        "table_name": table_name,
                    }
                elif db_type == "postgresql":
                    query = text(
                        "SELECT column_name, data_type, is_nullable, column_default, '' "
                        "FROM information_schema.columns "
                        "WHERE table_name = :table_name"
                    )
                elif db_type == "sqlite":
                    # PRAGMA arguments cannot be bound, so validate and inline.
                    if not _is_safe_identifier(table_name):
                        return f"Error describing table: invalid table name '{table_name}'"
                    query = text(f'PRAGMA table_info("{table_name}")')
                    params = {}
                elif db_type == "oracle":
                    # Columns are qualified because both views expose
                    # table_name and column_name.
                    query = text(
                        "SELECT c.column_name, c.data_type, c.nullable, c.data_default, m.comments "
                        "FROM user_tab_columns c "
                        "LEFT JOIN user_col_comments m "
                        "ON c.table_name = m.table_name AND c.column_name = m.column_name "
                        "WHERE c.table_name = :table_name"
                    )
                else:
                    return "Unsupported database type."

                columns = conn.execute(query, params).fetchall()
                if not columns:
                    return f"No such table: {table_name}"

                description = (
                    f"Table '{table_name}' in the database has the following columns:\n"
                )
                for column in columns:
                    # Only MySQL and SQLite report key information; default to
                    # "no key" so the check below never sees an unbound name.
                    column_key = ""
                    if db_type == "sqlite":
                        # Rows are (cid, name, type, notnull, dflt_value, pk).
                        _cid, column_name, data_type, not_null, _default, pk = column
                        is_nullable = "NO" if not_null else "YES"
                        column_key = "PRI" if pk else ""
                        column_comment = ""
                    elif db_type in ("oracle", "postgresql"):
                        # The fourth column is the column default, not a key marker.
                        (
                            column_name,
                            data_type,
                            is_nullable,
                            _default,
                            column_comment,
                        ) = column
                    else:
                        (
                            column_name,
                            data_type,
                            is_nullable,
                            column_key,
                            column_comment,
                        ) = column

                    description += f"- {column_name} ({data_type})"
                    if is_nullable == "YES" or is_nullable == "Y":
                        description += " [Nullable]"
                    if column_key == "PRI":
                        description += " [Primary Key]"
                    if column_comment:
                        description += f" [Comment: {column_comment}]"
                    description += "\n"
                return description
        except SQLAlchemyError as e:
            return f"Error describing table: {str(e)}"
        finally:
            engine.dispose()

    def execute_read_query(self, query: str) -> str:
        """
        Execute a read query and return the result in CSV format.

        :param query: The SQL query to execute.
        :return: A string containing the result of the query in CSV format.
        """
        print(f"Executing query: {query}")
        normalized_query = query.strip().lower()

        # The statement must start with one of the read-only verbs.
        if not re.match(
            r"^\s*(select|with|show|describe|desc|explain|use)\s", normalized_query
        ):
            return "Error: Only read-only queries (SELECT, WITH, SHOW, DESCRIBE, EXPLAIN, USE) are allowed. CREATE, DELETE, INSERT, UPDATE, DROP, and ALTER operations are not permitted."

        # Reject any write/DDL keyword appearing anywhere as a whole word.
        sensitive_keywords = [
            "insert",
            "update",
            "delete",
            "create",
            "drop",
            "alter",
            "truncate",
            "grant",
            "revoke",
            "replace",
        ]
        for keyword in sensitive_keywords:
            if re.search(rf"\b{keyword}\b", normalized_query):
                return f"Error: Query contains a sensitive keyword '{keyword}'. Only read operations are allowed."

        engine = self._get_engine()  # a fresh engine is created for every call
        try:
            with engine.connect() as conn:
                result = conn.execute(text(query))
                rows = result.fetchall()
                if not rows:
                    return "No data returned from query."

                column_names = result.keys()
                csv_data = f"Query executed successfully. Below is the actual result of the query {query} running against the database in CSV format:\n\n"

                # csv.writer quotes values containing commas, quotes or line
                # breaks, which a plain ",".join would turn into broken rows.
                buffer = io.StringIO()
                writer = csv.writer(buffer, lineterminator="\n")
                writer.writerow(list(column_names))
                for row in rows:
                    writer.writerow([str(value) for value in row])
                return csv_data + buffer.getvalue()
        except SQLAlchemyError as e:
            return f"Error executing query: {str(e)}"
        finally:
            engine.dispose()


# Request that followed this code on the original page, kept for context:
# "Connect the tool above to Microsoft SQL Server" (将上面的工具连接到Microsoft sql server)
07-23
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值