TableRow 中子视图 match_parent 失效的问题

在 Android 开发中,当 TableRow 内包含多个子视图时,为子视图设置的 match_parent 属性可能不生效。将这些子视图的 android:layout_weight 设置为 1 即可解决此问题:TableRow 本身是水平方向的 LinearLayout,剩余的水平空间只会按 layout_weight 分配给子视图。

参考:http://stackoverflow.com/questions/10231026/android-tablerow-children-issue


在 Android 2.2(API 8)上发现:当 TableRow 中放入多个子视图时,第二个视图设置的 match_parent 属性无效。

根据上述链接中的说法,为子视图设置 android:layout_weight="1" 即可解决。

package com.example.kucun2.ui.dingdan;//package com.example.kucun2; import static android.content.ContentValues.TAG; import android.animation.Animator; import android.animation.AnimatorListenerAdapter; import android.animation.ObjectAnimator; import android.animation.ValueAnimator; import android.app.AlertDialog; import android.content.Context; import android.graphics.Color; import android.graphics.Typeface; import android.os.Bundle; import android.util.Log; import android.util.TypedValue; import android.view.Gravity; import android.view.LayoutInflater; import android.view.View; import android.view.ViewGroup; import android.view.ViewTreeObserver; import android.widget.AdapterView; import android.widget.ArrayAdapter; import android.widget.Button; import android.widget.FrameLayout; import android.widget.HorizontalScrollView; import android.widget.LinearLayout; import android.widget.PopupMenu; import android.widget.SearchView; import android.widget.Spinner; import android.widget.TableLayout; import android.widget.TableRow; import android.widget.TextView; import android.widget.Toast; import androidx.annotation.NonNull; import androidx.core.content.ContextCompat; import androidx.fragment.app.Fragment; import com.example.kucun2.MainActivity; import com.example.kucun2.R; import com.example.kucun2.View.HorizontalScrollTextView; import com.example.kucun2.entity.Bancai; import com.example.kucun2.entity.Chanpin; import com.example.kucun2.entity.Chanpin_Zujian; import com.example.kucun2.entity.Dingdan; import com.example.kucun2.entity.Dingdan_chanpin_zujian; import com.example.kucun2.entity.Dingdan_Chanpin; import com.example.kucun2.entity.data.Data; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; public class OrderDisplayFragment extends Fragment { private TableLayout table; private HorizontalScrollView 
horizontalScrollView; private ValueAnimator scrollIndicatorAnimator; private boolean isIndicatorVisible = false; // 添加排序相关的成员变量 private int currentSortColumn = -1; private boolean sortAscending = true; private List<Object[]> allTableRowsData = new ArrayList<>(); // 添加搜索相关成员变量 private SearchView searchView; private Spinner columnSelector; private List<Object[]> filteredTableRowsData = new ArrayList<>(); private boolean isDataLoaded = false; /** * 加载初始化 * * @param inflater The LayoutInflater object that can be used to inflate * any views in the fragment, * @param container If non-null, this is the parent view that the fragment's * UI should be attached to. The fragment should not add the view itself, * but this can be used to generate the LayoutParams of the view. * @param savedInstanceState If non-null, this fragment is being re-constructed * from a previous saved state as given here. * @return */ @Override public View onCreateView(LayoutInflater inflater, ViewGroup container, Bundle savedInstanceState) { View view = inflater.inflate(R.layout.fragment_order_display, container, false); table = view.findViewById(R.id.orderTable); horizontalScrollView = view.findViewById(R.id.horizontalScrollContainer); View scrollIndicator = view.findViewById(R.id.scroll_indicator); // 获取搜索控件 searchView = view.findViewById(R.id.search_view); columnSelector = view.findViewById(R.id.column_selector); // 初始化表头选择器 initColumnSelector(); // 设置搜索监听 searchView.setOnQueryTextListener(new SearchView.OnQueryTextListener() { @Override public boolean onQueryTextSubmit(String query) { applySearchFilter(); return true; } @Override public boolean onQueryTextChange(String newText) { applySearchFilter(); return true; } }); LinearLayout fixedSearchBar = view.findViewById(R.id.fixedSearchBar); View placeholder = view.findViewById(R.id.search_bar_placeholder); // 添加全局布局监听器以获取正确的搜索框高度 fixedSearchBar.getViewTreeObserver().addOnGlobalLayoutListener( new ViewTreeObserver.OnGlobalLayoutListener() { @Override 
public void onGlobalLayout() { // 获取搜索框的实际高度 int searchBarHeight = fixedSearchBar.getHeight(); // 设置占位视图的高度 ViewGroup.LayoutParams params = placeholder.getLayoutParams(); params.height = searchBarHeight; placeholder.setLayoutParams(params); // 确保仅运行一次 fixedSearchBar.getViewTreeObserver().removeOnGlobalLayoutListener(this); } } ); Log.d(TAG, "onCreateView: " + Data.dingdans.get(0).getNumber()); // 添加表头 addTableHeader(table); // 检查数据是否已加载 if (Data.dingdans.isEmpty()) { // 显示加载指示器 showLoadingIndicator(); // 设置数据加载监听器 if (getActivity() instanceof MainActivity) { ((MainActivity) getActivity()).setOnDataLoadListener(new MainActivity.OnDataLoadListener() { @Override public void onDataLoaded() { requireActivity().runOnUiThread(() -> { hideLoadingIndicator(); isDataLoaded = true; fillTableData(); // 填充数据 }); } @Override public void onDataError() { requireActivity().runOnUiThread(() -> { hideLoadingIndicator(); Toast.makeText(getContext(), "检查网络", Toast.LENGTH_SHORT).show(); //showError("数据加载失败"); }); } }); } } else { // 数据已加载,直接填充 fillTableData(); isDataLoaded = true; } // 填充表格数据 // fillTableData(); // 添加滚动监听 horizontalScrollView.getViewTreeObserver().addOnScrollChangedListener(() -> { int maxScroll = horizontalScrollView.getChildAt(0).getWidth() - horizontalScrollView.getWidth(); int currentScroll = horizontalScrollView.getScrollX(); if (currentScroll > 0 && maxScroll > 0) { if (!isIndicatorVisible) { showScrollIndicator(); } // 更新滚动指示器位置 updateScrollIndicatorPosition(currentScroll, maxScroll); } else { hideScrollIndicator(); } }); return view; } // 显示/隐藏加载指示器的方法 private void showLoadingIndicator() { // 实现加载动画或进度条 } private void hideLoadingIndicator() { // 隐藏加载指示器 } @Override public void onAttach(@NonNull Context context) { super.onAttach(context); if (context instanceof MainActivity) { ((MainActivity) context).setOnDataLoadListener(new MainActivity.OnDataLoadListener() { @Override public void onDataLoaded() { // 数据加载完成后填充表格 getActivity().runOnUiThread(() -> { 
Log.d("DataLoad", "Data loaded, filling table"); fillTableData(); }); } @Override public void onDataError() { //showToast("数据加载失败"); } }); } } /** * 获取数据 */ private void fillTableData() { List<Dingdan> orders = Data.dingdans; List<Dingdan_Chanpin> orderProducts = Data.dingdan_chanpins; List<Dingdan_chanpin_zujian> orderMaterials = Data.Dingdan_chanpin_zujians; allTableRowsData.clear(); filteredTableRowsData.clear(); // 创建映射关系提高效率 Map<Integer, List<Dingdan_Chanpin>> orderProductMap = new HashMap<>(); Map<Integer, List<Chanpin_Zujian>> productComponentMap = new HashMap<>(); Map<Integer, List<Dingdan_chanpin_zujian>> componentMaterialMap = new HashMap<>(); // 构建映射 for (Dingdan_Chanpin op : orderProducts) { if (op != null && op.getDingdan() != null) { int orderId = op.getDingdan().getId(); orderProductMap.computeIfAbsent(orderId, k -> new ArrayList<>()).add(op); } } for (Chanpin_Zujian cz : Data.chanpin_zujians) { int productId = cz.getChanpin().getId(); productComponentMap.computeIfAbsent(productId, k -> new ArrayList<>()).add(cz); } for (Dingdan_chanpin_zujian dm : orderMaterials) { int componentId = dm.getZujian().getId(); componentMaterialMap.computeIfAbsent(componentId, k -> new ArrayList<>()).add(dm); } // 重组数据 for (Dingdan order : orders) { List<Dingdan_Chanpin> productsForOrder = orderProductMap.get(order.getId()); if (productsForOrder != null) { for (Dingdan_Chanpin op : productsForOrder) { Chanpin product = op.getChanpin(); List<Chanpin_Zujian> componentsForProduct = productComponentMap.get(product.getId()); if (componentsForProduct != null) { for (Chanpin_Zujian cz : componentsForProduct) { List<Dingdan_chanpin_zujian> materialsForComponent = componentMaterialMap.get(cz.getZujian().getId()); if (materialsForComponent != null) { for (Dingdan_chanpin_zujian dm : materialsForComponent) { Object[] rowData = createRowData(order, product, op, cz, dm); allTableRowsData.add(rowData); filteredTableRowsData.add(rowData); } } } } } } } // 日志记录添加行数 Log.d("TableFill", 
"Total rows created: " + allTableRowsData.size()); // 初始排序 sortTableData(-1, true); } /** * 排序表格数据并刷新显示 * * @param columnIndex 要排序的列索引 * @param ascending 是否升序排列 */ private void sortTableData(int columnIndex, boolean ascending) { // 更新排序状态 if (columnIndex >= 0) { if (currentSortColumn == columnIndex) { // 相同列点击时切换排序方向 sortAscending = !ascending; } else { currentSortColumn = columnIndex; sortAscending = true; // 新列默认升序 } } // 创建排序比较器 Comparator<Object[]> comparator = (row1, row2) -> { if (currentSortColumn < 0) { return 0; // 返回0表示相等,保持原顺序 } Object value1 = row1[currentSortColumn]; Object value2 = row2[currentSortColumn]; if (value1 == null && value2 == null) return 0; if (value1 == null) return -1; if (value2 == null) return 1; // 根据不同列数据类型定制比较规则 try { // 数值列:2(数量), 5(板材/组件), 6(订购数量) if (currentSortColumn == 2 || currentSortColumn == 5 || currentSortColumn == 6) { double d1 = Double.parseDouble(value1.toString()); double d2 = Double.parseDouble(value2.toString()); return sortAscending ? Double.compare(d1, d2) : Double.compare(d2, d1); } // 其他列按字符串排序 else { String s1 = value1.toString().toLowerCase(); String s2 = value2.toString().toLowerCase(); return sortAscending ? s1.compareTo(s2) : s2.compareTo(s1); } } catch (NumberFormatException e) { // 解析失败时按字符串比较 String s1 = value1.toString().toLowerCase(); String s2 = value2.toString().toLowerCase(); return sortAscending ? 
s1.compareTo(s2) : s2.compareTo(s1); } }; // 特殊处理初始未排序状态 if (columnIndex == -1) { // 直接复制数据而不排序 filteredTableRowsData.clear(); filteredTableRowsData.addAll(allTableRowsData); } else { Collections.sort(filteredTableRowsData, comparator); } // 刷新显示 refreshTableWithData(filteredTableRowsData); } /** * 表格数据动态添加 * * @param rowData */ private void addTableRow(Object[] rowData) { TableRow row = new TableRow(requireContext()); TableLayout.LayoutParams rowParams = new TableLayout.LayoutParams( TableLayout.LayoutParams.MATCH_PARENT, TableLayout.LayoutParams.WRAP_CONTENT ); row.setLayoutParams(rowParams); row.setMinimumHeight(dpToPx(36)); for (int i = 0; i < rowData.length; i++) { final Object data = rowData[i]; // 判断是否为操作列(最后一列) if (i == rowData.length - 1) { // 创建操作按钮 Button actionButton = new Button(requireContext()); actionButton.setText("操作"); actionButton.setTextSize(TypedValue.COMPLEX_UNIT_SP, 14); actionButton.setBackgroundResource(R.drawable.btn_selector); // 自定义按钮样式 // 设置按钮点击监听器 actionButton.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { handleRowAction(rowData, v); } }); // 设置按钮布局参数 TableRow.LayoutParams btnParams = new TableRow.LayoutParams( 0, // 宽度由权重控制 TableRow.LayoutParams.WRAP_CONTENT, 0.5f ); btnParams.weight = 0.5f; int margin = dpToPx(1); btnParams.setMargins(margin, margin, margin, margin); actionButton.setLayoutParams(btnParams); actionButton.setHeight(11); row.addView(actionButton); } else { // 正常文本列的代码(保持原逻辑) HorizontalScrollTextView textView = new HorizontalScrollTextView(requireContext()); textView.setText(String.valueOf(data)); textView.setTextSize(TypedValue.COMPLEX_UNIT_SP, 14); int padding = dpToPx(8); textView.setPadding(padding, padding / 2, padding, padding); textView.setMinWidth(dpToPx(50)); TableRow.LayoutParams colParams = null; // 设置背景边框 textView.setBackgroundResource(R.drawable.cell_border); if (data.toString().length() > 10) { colParams = new TableRow.LayoutParams( 0, // 宽度将由权重控制 
TableRow.LayoutParams.MATCH_PARENT, 2.0f ); colParams.weight = 2; } else { colParams = new TableRow.LayoutParams( 0, // 宽度将由权重控制 TableRow.LayoutParams.MATCH_PARENT, 1.0f ); colParams.weight = 1; } textView.setLayoutParams(colParams); row.addView(textView); } } table.addView(row); } // 动态添加表头 (使用自定义TextView) private void addTableHeader(TableLayout table) { TableRow headerRow = new TableRow(requireContext()); headerRow.setLayoutParams(new TableLayout.LayoutParams( TableLayout.LayoutParams.MATCH_PARENT, TableLayout.LayoutParams.WRAP_CONTENT )); // 设置行背景颜色 headerRow.setBackgroundColor(ContextCompat.getColor(requireContext(), R.color.purple_500)); // 定义表头 // 更新表头数组(添加操作列) String[] headers = getResources().getStringArray(R.array.table_headers); List<String> headerList = new ArrayList<>(Arrays.asList(headers)); headerList.add("操作"); // 添加操作列标题 headers = headerList.toArray(new String[0]); // 更新权重数组(添加操作列权重) float[] weights = {1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 1.0f, 1.0f, 0.5f}; // 新增操作列权重0.5 // 更新优先级数组(添加操作列优先级) boolean[] priority = {false, false, false, false, true, false, false, false}; for (int i = 0; i < headers.length; i++) { HorizontalScrollTextView headerView = new HorizontalScrollTextView(requireContext()); headerView.setText(headers[i]); headerView.setTextColor(Color.WHITE); headerView.setTypeface(null, Typeface.BOLD); headerView.setTextSize(TypedValue.COMPLEX_UNIT_SP, 16); headerView.setPadding(dpToPx(8), dpToPx(8), dpToPx(8), dpToPx(8)); // 为优先级高的列设置最小宽度 if (priority[i]) { headerView.setMinWidth(dpToPx(220)); } // 设置布局参数 TableRow.LayoutParams colParams = new TableRow.LayoutParams( priority[i] ? TableRow.LayoutParams.WRAP_CONTENT : 0, TableRow.LayoutParams.MATCH_PARENT, priority[i] ? 
0 : weights[i] // 优先级列不使用权重 ); headerView.setLayoutParams(colParams); final int columnIndex = i; headerView.setOnClickListener(v -> { // 排序并刷新表格 sortTableData(columnIndex, sortAscending); // 更新排序指示器(可选) showSortIndicator(headerView); }); headerRow.addView(headerView); } table.addView(headerRow); } // 添加排序指示器(可选) private void showSortIndicator(View header) { // 实现:在表头右侧添加↑或↓指示符 // 实现逻辑根据设计需求 // header.setTooltipText(new ); } /** * */ private void showScrollIndicator() { isIndicatorVisible = true; View indicator = getView().findViewById(R.id.scroll_indicator); if (scrollIndicatorAnimator != null && scrollIndicatorAnimator.isRunning()) { scrollIndicatorAnimator.cancel(); } indicator.setVisibility(View.VISIBLE); indicator.setAlpha(0f); scrollIndicatorAnimator = ObjectAnimator.ofFloat(indicator, "alpha", 0f, 0.8f); scrollIndicatorAnimator.setDuration(300); scrollIndicatorAnimator.start(); } /** * + */ private void hideScrollIndicator() { isIndicatorVisible = false; View indicator = getView().findViewById(R.id.scroll_indicator); if (scrollIndicatorAnimator != null && scrollIndicatorAnimator.isRunning()) { scrollIndicatorAnimator.cancel(); } scrollIndicatorAnimator = ObjectAnimator.ofFloat(indicator, "alpha", indicator.getAlpha(), 0f); scrollIndicatorAnimator.setDuration(300); scrollIndicatorAnimator.addListener(new AnimatorListenerAdapter() { @Override public void onAnimationEnd(Animator animation) { indicator.setVisibility(View.INVISIBLE); } }); scrollIndicatorAnimator.start(); } /** * @param currentScroll * @param maxScroll */ private void updateScrollIndicatorPosition(int currentScroll, int maxScroll) { View indicator = getView().findViewById(R.id.scroll_indicator); FrameLayout.LayoutParams params = (FrameLayout.LayoutParams) indicator.getLayoutParams(); // 计算指示器位置(0-100%) float percentage = (float) currentScroll / maxScroll; int maxMargin = getResources().getDisplayMetrics().widthPixels - indicator.getWidth(); // 设置右边距(控制位置) params.rightMargin = (int) (maxMargin * 
percentage); indicator.setLayoutParams(params); } // 处理行操作的方法 private void handleRowAction(Object[] rowData, View anchorButton) { // 安全地从行数据中提取关键信息 String orderNumber = safeGetString(rowData[0]); // 订单号 String productId = safeGetString(rowData[1]); // 产品ID String componentName = safeGetString(rowData[3]); // 组件名称 // 安全地获取订购数量 double materialQuantity = 0.0; try { if (rowData[6] != null) { if (rowData[6] instanceof Number) { materialQuantity = ((Number) rowData[6]).doubleValue(); } else { materialQuantity = Double.parseDouble(rowData[6].toString()); } } } catch (Exception e) { Log.e("OrderFragment", "Failed to parse material quantity", e); } Context context = getContext(); if (context == null || anchorButton == null) { Log.w("PopupMenu", "Context or anchorButton is null"); return; } PopupMenu popupMenu = new PopupMenu(context, anchorButton); // 强制设置菜单在锚点视图下方显示(关键设置) popupMenu.setGravity(Gravity.BOTTOM); // 如果使用支持库,设置弹出方向 // 设置在锚点视图下方显示 // popupMenu.setOverlapAnchor(true); // 填充菜单项 popupMenu.getMenuInflater().inflate(R.menu.row_actions_menu, popupMenu.getMenu()); // 设置菜单项点击监听器 popupMenu.setOnMenuItemClickListener(item -> { int itemId = item.getItemId(); if (itemId == R.id.action_view_details) { showDetailDialog(orderNumber, productId); return true; } else if (itemId == R.id.action_edit) { editRowData(rowData); return true; } else if (itemId == R.id.action_delete) { deleteRowWithConfirm(rowData); return true; } return false; }); popupMenu.show(); } // 安全获取字符串值的方法 private String safeGetString(Object value) { if (value == null) return ""; if (value instanceof String) return (String) value; return value.toString(); } // 查看详情对话框 private void showDetailDialog(String orderNumber, String productId) { AlertDialog.Builder builder = new AlertDialog.Builder(requireContext()); builder.setTitle("订单详情") .setMessage("订单号: " + orderNumber + "\n产品ID: " + productId) .setPositiveButton("确定", null) .show(); } // 编辑行数据 private void editRowData(Object[] rowData) { // 实现编辑逻辑 // 这里创建包含表单的对话框 
Toast.makeText(requireContext(), "编辑操作: " + rowData[0], Toast.LENGTH_SHORT).show(); } // 带确认的删除操作 private void deleteRowWithConfirm(Object[] rowData) { new AlertDialog.Builder(requireContext()) .setTitle("确认删除") .setMessage("确定要删除订单 " + rowData[0] + " 吗?") .setPositiveButton("删除", (dialog, which) -> { // 实际删除逻辑 deleteRow(rowData); }) .setNegativeButton("取消", null) .show(); } // 实际删除行数据 private void deleteRow(Object[] rowData) { // 1. 从allTableRowsData中移除对应行 for (Iterator<Object[]> iterator = allTableRowsData.iterator(); iterator.hasNext(); ) { Object[] row = iterator.next(); if (Arrays.equals(row, rowData)) { iterator.remove(); break; } } // 2. 从filteredTableRowsData中移除 filteredTableRowsData.removeIf(row -> Arrays.equals(row, rowData)); // 3. 刷新表格 refreshTableWithData(filteredTableRowsData); Toast.makeText(requireContext(), "已删除订单", Toast.LENGTH_SHORT).show(); } // DP转PX工具方法 private int dpToPx(int dp) { return (int) TypedValue.applyDimension( TypedValue.COMPLEX_UNIT_DIP, dp, getResources().getDisplayMetrics() ); } /** * 数据组合 * * @param order * @param product * @param component * @param material * @return */ private Object[] createRowData(Dingdan order, Chanpin product, Dingdan_Chanpin dingdan_chanpin, Chanpin_Zujian component, Dingdan_chanpin_zujian material) { Bancai board = material.getBancai(); String boardInfo = board.TableText(); ; return new Object[]{ order.getNumber(), // 订单号 product.getBianhao(), // 产品编号 dingdan_chanpin.getShuliang(), // 产品数量 (根据需求调整) component.getZujian().getName(), // 组件名 boardInfo, // 板材信息 Math.round(component.getOne_howmany()), // 板材/组件 material.getShuliang(), // 订购数量 "操作" }; } // 初始化列选择器 private void initColumnSelector() { ArrayAdapter<CharSequence> adapter = ArrayAdapter.createFromResource( requireContext(), R.array.table_headers, android.R.layout.simple_spinner_item ); adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item); columnSelector.setAdapter(adapter); // 添加"所有列"选项 columnSelector.setSelection(0); // 
默认选择第一个选项(所有列) // 列选择变化监听 columnSelector.setOnItemSelectedListener(new AdapterView.OnItemSelectedListener() { @Override public void onItemSelected(AdapterView<?> parent, View view, int position, long id) { applySearchFilter(); } @Override public void onNothingSelected(AdapterView<?> parent) { } }); } // 应用搜索过滤 private void applySearchFilter() { String query = searchView.getQuery().toString().trim().toLowerCase(); int selectedColumn = columnSelector.getSelectedItemPosition(); filteredTableRowsData.clear(); if (query.isEmpty()) { // 没有搜索词,显示所有数据 filteredTableRowsData.addAll(allTableRowsData); } else { // 根据选择的列进行过滤 for (Object[] row : allTableRowsData) { // 如果选择"所有列"(位置0),检查所有列 if (selectedColumn == 0) { for (Object cell : row) { if (cell != null && cell.toString().toLowerCase().contains(query)) { filteredTableRowsData.add(row); break; } } } // 检查特定列 else if (selectedColumn >= 1 && selectedColumn <= row.length) { int columnIndex = selectedColumn - 1; // 调整索引(0=所有列,1=第一列) if (row[columnIndex] != null && row[columnIndex].toString().toLowerCase().contains(query)) { filteredTableRowsData.add(row); } } } } // 刷新表格显示 refreshTableWithData(filteredTableRowsData); } /** * 刷新表格显示 */ private void refreshTableWithData(Iterable<? 
extends Object[]> dataToShow) { // Log.d("TableRefresh", "Refreshing table with " + currentSortColumn + " rows"); // 添加调试信息 Log.d("TableRefresh", "Refreshing table with " + currentSortColumn + " rows"); removeAllRowsSafely(); int addedRows = 0; for (Object[] rowData : dataToShow) { addTableRow(rowData); addedRows++; } // 添加空数据提示 if (addedRows == 0) { addEmptyTableRow(); } } private void addEmptyTableRow() { TableRow row = new TableRow(requireContext()); TextView emptyView = new TextView(requireContext()); emptyView.setText("暂无数据"); emptyView.setGravity(Gravity.CENTER); emptyView.setLayoutParams(new TableRow.LayoutParams( TableRow.LayoutParams.MATCH_PARENT, TableRow.LayoutParams.WRAP_CONTENT )); row.addView(emptyView); table.addView(row); } private void removeAllRowsSafely() { // 移除除表头外的所有行(如果有表头) if (table.getChildCount() > 0) { // 保留表头(索引0) for (int i = table.getChildCount() - 1; i >= 1; i--) { View child = table.getChildAt(i); table.removeView(child); // 清理视图引用(非常重要!) cleanupRowViews((TableRow) child); } } } private void cleanupRowViews(TableRow row) { int childCount = row.getChildCount(); for (int i = 0; i < childCount; i++) { View view = row.getChildAt(i); // 解除视图的所有监听器 view.setOnClickListener(null); // 特别是操作按钮,需要取消所有监听器 if (view instanceof Button) { Button button = (Button) view; button.setOnClickListener(null); // 清空按钮的数据引用 button.setTag(null); } } // 从父视图中移除行 if (row.getParent() != null) { ((ViewGroup) row.getParent()).removeView(row); } } } E FATAL EXCEPTION: main Process: com.example.kucun2, PID: 2475 java.lang.NullPointerException: Attempt to invoke virtual method 'android.view.View android.view.View.findViewById(int)' on a null object reference at com.example.kucun2.ui.dingdan.OrderDisplayFragment.hideScrollIndicator(OrderDisplayFragment.java:553) at com.example.kucun2.ui.dingdan.OrderDisplayFragment.lambda$onCreateView$0(OrderDisplayFragment.java:199) at com.example.kucun2.ui.dingdan.OrderDisplayFragment.$r8$lambda$VEviHEc8N_lTbpxv9eW9n-N0BE8(Unknown 
Source:0) at com.example.kucun2.ui.dingdan.OrderDisplayFragment$$ExternalSyntheticLambda8.onScrollChanged(D8$$SyntheticClass:0) at android.view.ViewTreeObserver.dispatchOnScrollChanged(ViewTreeObserver.java:1301) at android.view.ViewRootImpl.draw(ViewRootImpl.java:6618) at android.view.ViewRootImpl.performDraw(ViewRootImpl.java:6390) at android.view.ViewRootImpl.performTraversals(ViewRootImpl.java:5268) at android.view.ViewRootImpl.doTraversal(ViewRootImpl.java:3667) at android.view.ViewRootImpl$TraversalRunnable.run(ViewRootImpl.java:12113) at android.view.Choreographer$CallbackRecord.run(Choreographer.java:2459) at android.view.Choreographer$CallbackRecord.run(Choreographer.java:2468) at android.view.Choreographer.doCallbacks(Choreographer.java:1693) at android.view.Choreographer.doFrame(Choreographer.java:1448) at android.view.Choreographer$FrameDisplayEventReceiver.run(Choreographer.java:2284) at android.os.Handler.handleCallback(Handler.java:1014) at android.os.Handler.dispatchMessage(Handler.java:102) at android.os.Looper.loopOnce(Looper.java:250) at android.os.Looper.loop(Looper.java:340) at android.app.ActivityThread.main(ActivityThread.java:9913) at java.lang.reflect.Method.invoke(Native Method) at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:621) at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:957)
06-14
import re from typing import Optional, Dict, List, Tuple, AsyncGenerator from datetime import date, datetime, timedelta, timezone from fastapi import Depends from fastapi.concurrency import asynccontextmanager from sqlalchemy.ext.asyncio import create_async_engine,AsyncSession from sqlalchemy.orm import sessionmaker from sqlalchemy import text # -------------------------- 数据库核心配置 -------------------------- # 建议从 .env 加载(原配置硬编码,此处优化为环境变量读取) import os from dotenv import load_dotenv load_dotenv() DATABASE_URL = os.getenv( "DATABASE_URL", "mysql+asyncmy://root:123456@localhost/ai_roleplay?charset=utf8mb4" # 默认 fallback ) # 异步引擎配置(优化连接池参数) engine = create_async_engine( DATABASE_URL, echo=False, # 生产环境设为 False,避免日志冗余 pool_pre_ping=True, # 连接前校验,防止失效连接 pool_size=10, # 常驻连接数 max_overflow=20, # 最大临时连接数 pool_recycle=3600 # 连接超时回收(1小时) ) # 异步 Session 工厂(线程安全) AsyncSessionLocal = sessionmaker( bind=engine, class_=AsyncSession, expire_on_commit=False, # 提交后不失效对象 autoflush=False # 关闭自动刷新,减少不必要 SQL ) # -------------------------- 通用工具函数 -------------------------- def get_conversation_table_name(user_id: str) -> str: """生成用户专属对话表名(防 SQL 注入)""" safe_id = "".join(c for c in str(user_id) if c.isalnum() or c == "_") return f"conversations_{safe_id}" def is_valid_table_name(table_name: str) -> bool: """校验表名合法性(仅允许 conversations_xxx 格式)""" return re.match(r'^conversations_[a-zA-Z0-9_]+$', table_name) is not None @asynccontextmanager async def get_default_db() -> AsyncGenerator[AsyncSession,None]: """ 自动创建默认数据库会话(上下文管理器,自动管理生命周期) 用于数据库函数的默认db参数,避免手动传参 """ async with AsyncSessionLocal() as db: # 基于全局工厂创建独立会话 try: yield db # 提供会话给函数使用 await db.commit() # 函数无异常则提交 except Exception as e: await db.rollback() # 异常则回滚 raise e # 重新抛出异常,让上层处理 finally: await db.close() # 无论成败都关闭会话 async def get_default_db_instance() -> AsyncSession: """ 获取默认db实例(供函数默认参数使用) 本质是触发 get_default_db() 上下文管理器,返回会话对象 """ return await anext(get_default_db()) # anext() 用于异步上下文管理器 # -------------------------- 用户基础操作 
-------------------------- async def get_user_by_account(account: str, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过账号查询用户(用于注册时判重、登录校验)""" result = await db.execute( text(""" SELECT id, account, password, role, department_id, created_at FROM users WHERE account = :account """), {"account": account} ) row = result.fetchone() return dict(row._mapping) if row else None async def get_user_by_id(user_id: str,db:AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过用户ID查询用户(用于角色修改、权限校验)""" result = await db.execute( text(""" SELECT id, account, role, department_id, created_at FROM users WHERE id = :user_id """), {"user_id": user_id} ) row = result.fetchone() return dict(row._mapping) if row else None async def get_user_detail(user_id: str,db:AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """获取用户详情(含院系名称,用于个人中心)""" result = await db.execute( text(""" SELECT u.id, u.account, u.role, u.department_id, u.created_at, d.name AS dept_name FROM users u LEFT JOIN departments d ON u.department_id = d.id WHERE u.id = :user_id """), {"user_id": user_id} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 用户创建与更新 -------------------------- async def create_user( account: str, password: str, role: str = "user", department_id: Optional[int] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """创建用户(注册专用),并自动创建专属对话表""" # 1. 插入用户记录 result = await db.execute( text(""" INSERT INTO users (account, password, role, department_id, created_at) VALUES (:account, :password, :role, :dept_id, NOW()) """), { "account": account, "password": password, "role": role, "dept_id": department_id } ) user_id = result.lastrowid # 获取自增ID # 2. 
创建用户专属对话表(关联 AI 角色) table_name = get_conversation_table_name(user_id) if not is_valid_table_name(table_name): raise ValueError(f"Invalid user ID for conversation table: {user_id}") await db.execute(text(f""" CREATE TABLE IF NOT EXISTS `{table_name}` ( id INT AUTO_INCREMENT PRIMARY KEY, character_id INT NOT NULL, # 关联 AI 角色表 user_message TEXT NOT NULL, # 用户消息 ai_message TEXT NOT NULL, # AI 回复 timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (character_id) REFERENCES characters(id) ON DELETE CASCADE ) ENGINE=InnoDB CHARSET=utf8mb4; """)) # 3. 返回创建的用户信息 return await get_user_by_id(db, user_id) async def update_user( user_id: str, update_params: Dict, # 支持更新:password、role、department_id db: AsyncSession = Depends(get_default_db_instance) ) -> None: """更新用户信息(动态拼接 SQL,避免冗余)""" if not update_params: return # 无参数则不执行 # 动态生成更新字段(防注入:仅允许指定字段) allowed_fields = ["password", "role", "department_id"] set_clause = ", ".join([f"{k} = :{k}" for k in update_params if k in allowed_fields]) if not set_clause: return # 补充用户ID参数 params = {**update_params, "user_id": user_id} await db.execute( text(f"UPDATE users SET {set_clause} WHERE id = :user_id"), params ) # -------------------------- 用户列表与统计 -------------------------- async def get_users_list( page: int = 1, size: int = 10, role: Optional[str] = None, dept_id: Optional[int] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """分页查询用户列表(管理员专用,支持角色/院系筛选)""" # 1. 构建筛选条件 where_clause = [] params = {"offset": (page - 1) * size, "limit": size} if role: where_clause.append("role = :role") params["role"] = role if dept_id is not None: where_clause.append("department_id = :dept_id") params["dept_id"] = dept_id where_sql = "WHERE " + " AND ".join(where_clause) if where_clause else "" # 2. 查询总数(用于分页) total_result = await db.execute( text(f"SELECT COUNT(*) AS total FROM users {where_sql}"), params ) total = total_result.scalar() # 3. 
查询分页数据 data_result = await db.execute( text(f""" SELECT id, account, role, department_id, created_at FROM users {where_sql} ORDER BY created_at DESC LIMIT :offset, :limit """), params ) users = [dict(row._mapping) for row in data_result.fetchall()] return total, users async def get_user_count_by_dept(dept_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> int: """统计指定院系的用户数(删除院系前校验)""" result = await db.execute( text("SELECT COUNT(*) FROM users WHERE department_id = :dept_id"), {"dept_id": dept_id} ) return result.scalar() # -------------------------- 原登录校验修正 -------------------------- async def check_users(account: str, password: str, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Tuple[str, str]]: """仅校验用户账号密码(原逻辑拆分,插入用户移至 create_user)""" result = await db.execute( text("SELECT id, password FROM users WHERE account = :account"), {"account": account} ) row = result.fetchone() return (str(row.id), row.password) if row else None # -------------------------- 院系基础操作 -------------------------- async def create_department( name: str, description: Optional[str] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """创建院系(管理员专用)""" result = await db.execute( text(""" INSERT INTO departments (name, description, created_at) VALUES (:name, :desc, NOW()) """), {"name": name, "desc": description} ) dept_id = result.lastrowid return await get_department_by_id(db, dept_id) async def get_department_by_id(dept_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过ID查询院系""" result = await db.execute( text("SELECT id, name, description, created_at FROM departments WHERE id = :dept_id"), {"dept_id": dept_id} ) row = result.fetchone() return dict(row._mapping) if row else None async def get_department_by_name(name: str, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过名称查询院系(创建时判重)""" result = await db.execute( text("SELECT id, name FROM departments WHERE name = :name"), 
{"name": name} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 院系列表与统计 -------------------------- async def get_departments_with_user_count(db: AsyncSession = Depends(get_default_db_instance)) -> List[Dict]: """获取所有院系(含用户数统计)""" result = await db.execute( text(""" SELECT d.id, d.name, d.description, d.created_at, COUNT(u.id) AS user_count FROM departments d LEFT JOIN users u ON d.id = u.department_id GROUP BY d.id ORDER BY d.created_at DESC """) ) return [dict(row._mapping) for row in result.fetchall()] # -------------------------- 院系更新与删除 -------------------------- async def update_department( dept_id: int, update_params: Dict, # 支持更新:name、description db: AsyncSession = Depends(get_default_db_instance) ) -> None: """更新院系信息(管理员专用)""" allowed_fields = ["name", "description"] set_clause = ", ".join([f"{k} = :{k}" for k in update_params if k in allowed_fields]) if not set_clause: return params = {**update_params, "dept_id": dept_id} await db.execute( text(f"UPDATE departments SET {set_clause} WHERE id = :dept_id"), params ) async def delete_department(dept_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> None: """删除院系(需先确保无用户关联)""" await db.execute( text("DELETE FROM departments WHERE id = :dept_id"), {"dept_id": dept_id} ) # -------------------------- 院系专属资源查询 -------------------------- async def get_dept_exclusive_rooms(dept_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> List[Dict]: """获取院系专属聊天室(仅本院系用户可见)""" result = await db.execute( text(""" SELECT id, name, description, creator_id, created_at FROM rooms WHERE type = 'dept' AND dept_id = :dept_id ORDER BY created_at DESC """), {"dept_id": dept_id} ) return [dict(row._mapping) for row in result.fetchall()] async def get_dept_exclusive_shares( dept_id: int, page: int = 1, size: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """获取院系专属分享(分页,仅本院系用户可见)""" # 1. 
统计总数 total_result = await db.execute( text(""" SELECT COUNT(*) AS total FROM shares s JOIN users u ON s.author_id = u.id WHERE s.type = 'dept' AND u.department_id = :dept_id """), {"dept_id": dept_id} ) total = total_result.scalar() # 2. 查询分页数据 data_result = await db.execute( text(""" SELECT s.*, u.account AS author_account FROM shares s JOIN users u ON s.author_id = u.id WHERE s.type = 'dept' AND u.department_id = :dept_id ORDER BY s.created_at DESC LIMIT :offset, :limit """), { "dept_id": dept_id, "offset": (page - 1) * size, "limit": size } ) shares = [dict(row._mapping) for row in data_result.fetchall()] return total, shares # -------------------------- AI角色操作 -------------------------- async def get_all_characters(db: AsyncSession = Depends(get_default_db_instance)) -> List[Dict]: """获取所有AI角色(用于聊天页面角色选择)""" result = await db.execute( text("SELECT id, name, trait, avatar_url FROM characters ORDER BY name ASC") ) return [dict(row._mapping) for row in result.fetchall()] async def get_character_by_id(character_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过ID查询AI角色(聊天时获取角色设定)""" result = await db.execute( text("SELECT id, name, trait FROM characters WHERE id = :character_id"), {"character_id": character_id} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 对话历史操作 -------------------------- async def save_conversation( user_id: int, character_id: int, user_message: str, ai_message: str, db: AsyncSession = Depends(get_default_db_instance) ) -> None: """保存用户与AI的对话(聊天后存储)""" table_name = get_conversation_table_name(user_id) if not is_valid_table_name(table_name): raise ValueError(f"Invalid user ID: {user_id}") await db.execute( text(f""" INSERT INTO `{table_name}` (character_id, user_message, ai_message) VALUES (:char_id, :user_msg, :ai_msg) """), { "char_id": character_id, "user_msg": user_message, "ai_msg": ai_message } ) async def load_conversation_history( user_id: str, 
character_id: Optional[int] = None, max_count: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> List[Dict]: """加载用户对话历史(支持按AI角色筛选)""" table_name = get_conversation_table_name(user_id) if not is_valid_table_name(table_name): return [] # 表名无效则返回空历史 # 构建筛选条件(可选按角色筛选) where_clause = "WHERE character_id = :char_id" if character_id else "" params = {"limit": max_count} if character_id: params["char_id"] = character_id # 查询最近的 max_count 条历史(时间正序) result = await db.execute( text(f""" SELECT user_message, ai_message, timestamp FROM `{table_name}` {where_clause} ORDER BY timestamp DESC LIMIT :limit """), params ) rows = result.fetchall() # 转换为 [{user: ..., ai: ...}, ...] 格式,按时间正序排列 history = [ { "user": row.user_message, "ai": row.ai_message, "time": row.timestamp.strftime("%Y-%m-%d %H:%M:%S") } for row in reversed(rows) # 反转后变为时间正序 ] return history # -------------------------- 用户个性化设定 -------------------------- async def get_user_profile(user_id: str,db:AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """获取用户个性化设定(如自定义角色配置)""" result = await db.execute( text(""" SELECT personality, role_setting FROM user_profiles WHERE user_id = :user_id """), {"user_id": user_id} ) row = result.fetchone() return dict(row._mapping) if row else None async def create_or_update_user_profile( user_id: str, personality: str, role_setting: str, db: AsyncSession = Depends(get_default_db_instance) ) -> bool: """创建/更新用户个性化设定(存在则更新,不存在则创建)""" await db.execute( text(""" INSERT INTO user_profiles (user_id, personality, role_setting, updated_at) VALUES (:user_id, :personality, :role_setting, NOW()) ON DUPLICATE KEY UPDATE personality = VALUES(personality), role_setting = VALUES(role_setting), updated_at = NOW() """), { "user_id": user_id, "personality": personality.strip(), "role_setting": role_setting.strip() } ) return True # -------------------------- 聊天室基础操作 -------------------------- async def create_room( name: str, type: str, # 类型:public(公开)、dept(院系)、ai(AI专属) 
creator_id: str, dept_id: Optional[int] = None, ai_character_id: Optional[int] = None, description: Optional[str] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """创建聊天室(支持三种类型)""" result = await db.execute( text(""" INSERT INTO rooms ( name, type, dept_id, ai_character_id, description, creator_id, created_at ) VALUES (:name, :type, :dept_id, :ai_char_id, :desc, :creator_id, NOW()) """), { "name": name, "type": type, "dept_id": dept_id, "ai_char_id": ai_character_id, "desc": description, "creator_id": creator_id } ) room_id = result.lastrowid return await get_room_by_id(db, room_id) async def get_room_by_id(room_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过ID查询聊天室详情""" result = await db.execute( text(""" SELECT r.*, u.account AS creator_account, c.name AS ai_char_name # 关联AI角色名称(若为AI专属) FROM rooms r JOIN users u ON r.creator_id = u.id LEFT JOIN characters c ON r.ai_character_id = c.id WHERE r.id = :room_id """), {"room_id": room_id} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 聊天室列表查询 -------------------------- async def get_rooms( type: Optional[str] = None, dept_id: Optional[int] = None, page: int = 1, size: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """分页查询聊天室(支持按类型、院系筛选)""" # 1. 构建筛选条件 where_clause = [] params = {"offset": (page - 1) * size, "limit": size} if type: where_clause.append("r.type = :type") params["type"] = type if dept_id is not None: where_clause.append("r.dept_id = :dept_id") params["dept_id"] = dept_id where_sql = "WHERE " + " AND ".join(where_clause) if where_clause else "" # 2. 统计总数 total_result = await db.execute( text(f""" SELECT COUNT(*) AS total FROM rooms r {where_sql} """), params ) total = total_result.scalar() # 3. 
查询分页数据 data_result = await db.execute( text(f""" SELECT r.id, r.name, r.type, r.description, r.creator_id, u.account AS creator_account, r.created_at, COUNT(rm.user_id) AS member_count FROM rooms r JOIN users u ON r.creator_id = u.id LEFT JOIN room_members rm ON r.id = rm.room_id {where_sql} GROUP BY r.id ORDER BY r.created_at DESC LIMIT :offset, :limit """), params ) rooms = [dict(row._mapping) for row in data_result.fetchall()] return total, rooms # -------------------------- 聊天室成员管理 -------------------------- async def check_room_member(room_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> bool: """校验用户是否为聊天室成员""" result = await db.execute( text(""" SELECT 1 FROM room_members WHERE room_id = :room_id AND user_id = :user_id """), {"room_id": room_id, "user_id": user_id} ) return result.scalar() is not None async def add_room_member(room_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> None: """添加用户到聊天室成员(加入聊天室)""" await db.execute( text(""" INSERT IGNORE INTO room_members (room_id, user_id, joined_at) VALUES (:room_id, :user_id, NOW()) """), {"room_id": room_id, "user_id": user_id} ) async def remove_room_member(room_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> None: """从聊天室移除用户(离开聊天室)""" await db.execute( text(""" DELETE FROM room_members WHERE room_id = :room_id AND user_id = :user_id """), {"room_id": room_id, "user_id": user_id} ) # -------------------------- 聊天室消息操作 -------------------------- async def create_room_message( room_id: int, sender_id: str, content: str, sent_at: Optional[datetime] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """发送聊天室消息(成员专用)""" sent_at = sent_at or datetime.now() result = await db.execute( text(""" INSERT INTO room_messages ( room_id, sender_id, content, sent_at ) VALUES (:room_id, :sender_id, :content, :sent_at) """), { "room_id": room_id, "sender_id": sender_id, "content": content, "sent_at": sent_at } ) 
msg_id = result.lastrowid # 返回消息详情(含发送者账号) msg_result = await db.execute( text(""" SELECT rm.id, rm.content, rm.sent_at, u.account AS sender_account FROM room_messages rm JOIN users u ON rm.sender_id = u.id WHERE rm.id = :msg_id """), {"msg_id": msg_id} ) return dict(msg_result.fetchone()._mapping) async def get_room_messages( room_id: int, page: int = 1, size: int = 20, order_by: str = "sent_at DESC", db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """分页获取聊天室历史消息(支持排序)""" # 1. 统计总数 total_result = await db.execute( text("SELECT COUNT(*) AS total FROM room_messages WHERE room_id = :room_id"), {"room_id": room_id} ) total = total_result.scalar() # 2. 查询分页数据(防排序注入:仅允许指定排序字段) valid_order = ["sent_at ASC", "sent_at DESC"] order_sql = order_by if order_by in valid_order else "sent_at DESC" data_result = await db.execute( text(f""" SELECT rm.id, rm.content, rm.sent_at, u.id AS sender_id, u.account AS sender_account FROM room_messages rm JOIN users u ON rm.sender_id = u.id WHERE rm.room_id = :room_id ORDER BY {order_sql} LIMIT :offset, :limit """), { "room_id": room_id, "offset": (page - 1) * size, "limit": size } ) messages = [dict(row._mapping) for row in data_result.fetchall()] return total, messages # -------------------------- 聊天室删除 -------------------------- async def delete_room(room_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> None: """删除聊天室(级联删除成员和消息)""" # 1. 删除成员关联 await db.execute( text("DELETE FROM room_members WHERE room_id = :room_id"), {"room_id": room_id} ) # 2. 删除消息 await db.execute( text("DELETE FROM room_messages WHERE room_id = :room_id"), {"room_id": room_id} ) # 3. 
删除聊天室本身 await db.execute( text("DELETE FROM rooms WHERE id = :room_id"), {"room_id": room_id} ) # -------------------------- 分享基础操作 -------------------------- async def create_share( title: str, content: str, author_id: str, is_public: bool = True, type: str = "public", # 类型:public(公开)、private(私有)、dept(院系) ai_character_id: Optional[int] = None, created_at: Optional[datetime] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """发布分享(支持三种类型)""" created_at = created_at or datetime.now() result = await db.execute( text(""" INSERT INTO shares ( title, content, author_id, is_public, type, ai_character_id, view_count, like_count, comment_count, created_at ) VALUES ( :title, :content, :author_id, :is_public, :type, :ai_char_id, 0, 0, 0, :created_at ) """), { "title": title, "content": content, "author_id": author_id, "is_public": is_public, "type": type, "ai_char_id": ai_character_id, "created_at": created_at } ) share_id = result.lastrowid return await get_share_by_id(db, share_id) async def get_share_by_id(share_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过ID查询分享详情(含作者信息)""" result = await db.execute( text(""" SELECT s.*, u.account AS author_account, u.department_id, c.name AS ai_char_name # 关联AI角色名称 FROM shares s JOIN users u ON s.author_id = u.id LEFT JOIN characters c ON s.ai_character_id = c.id WHERE s.id = :share_id """), {"share_id": share_id} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 分享列表查询 -------------------------- async def get_shares( is_public: Optional[bool] = None, author_id: Optional[str] = None, type: Optional[str] = None, order_by: str = "created_at DESC", page: int = 1, size: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """分页查询分享(支持按公开性、作者、类型筛选)""" # 1. 
构建筛选条件 where_clause = [] params = {"offset": (page - 1) * size, "limit": size} if is_public is not None: where_clause.append("s.is_public = :is_public") params["is_public"] = is_public if author_id: where_clause.append("s.author_id = :author_id") params["author_id"] = author_id if type: where_clause.append("s.type = :type") params["type"] = type where_sql = "WHERE " + " AND ".join(where_clause) if where_clause else "" # 2. 统计总数 total_result = await db.execute( text(f"SELECT COUNT(*) AS total FROM shares s {where_sql}"), params ) total = total_result.scalar() # 3. 查询分页数据(防排序注入) valid_order = ["created_at DESC", "created_at ASC", "like_count DESC", "view_count DESC"] order_sql = order_by if order_by in valid_order else "created_at DESC" data_result = await db.execute( text(f""" SELECT s.*, u.account AS author_account, c.name AS ai_char_name FROM shares s JOIN users u ON s.author_id = u.id LEFT JOIN characters c ON s.ai_character_id = c.id {where_sql} ORDER BY {order_sql} LIMIT :offset, :limit """), params ) shares = [dict(row._mapping) for row in data_result.fetchall()] return total, shares # -------------------------- 分享更新与删除 -------------------------- async def update_share( share_id: int, update_params: Dict, # 支持更新:title、content、is_public、type、ai_character_id、view_count等 db: AsyncSession = Depends(get_default_db_instance) ) -> None: """更新分享信息(作者专用)""" allowed_fields = [ "title", "content", "is_public", "type", "ai_character_id", "view_count", "like_count", "comment_count" ] set_clause = ", ".join([f"{k} = :{k}" for k in update_params if k in allowed_fields]) if not set_clause: return params = {**update_params, "share_id": share_id} await db.execute( text(f"UPDATE shares SET {set_clause} WHERE id = :share_id"), params ) async def delete_share(share_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> None: """删除分享(级联删除评论和点赞)""" # 1. 删除点赞关联 await db.execute( text("DELETE FROM share_likes WHERE share_id = :share_id"), {"share_id": share_id} ) # 2. 
删除评论 await db.execute( text("DELETE FROM comments WHERE share_id = :share_id"), {"share_id": share_id} ) # 3. 删除分享本身 await db.execute( text("DELETE FROM shares WHERE id = :share_id"), {"share_id": share_id} ) # -------------------------- 评论操作 -------------------------- async def create_comment( share_id: int, commenter_id: str, content: str, parent_id: Optional[int] = None, created_at: Optional[datetime] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """发表评论(支持回复父评论)""" created_at = created_at or datetime.now() result = await db.execute( text(""" INSERT INTO comments ( share_id, commenter_id, parent_id, content, created_at ) VALUES (:share_id, :commenter_id, :parent_id, :content, :created_at) """), { "share_id": share_id, "commenter_id": commenter_id, "parent_id": parent_id, "content": content, "created_at": created_at } ) comment_id = result.lastrowid # 返回评论详情(含评论者账号) comm_result = await db.execute( text(""" SELECT c.id, c.content, c.parent_id, c.created_at, u.account AS commenter_account FROM comments c JOIN users u ON c.commenter_id = u.id WHERE c.id = :comment_id """), {"comment_id": comment_id} ) return dict(comm_result.fetchone()._mapping) async def get_share_comments( share_id: int, page: int = 1, size: int = 20, order_by: str = "created_at DESC", db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """分页获取分享的评论(含子评论层级)""" # 1. 统计总数 total_result = await db.execute( text("SELECT COUNT(*) AS total FROM comments WHERE share_id = :share_id"), {"share_id": share_id} ) total = total_result.scalar() # 2. 
查询分页数据(先查父评论,再关联子评论) valid_order = ["created_at ASC", "created_at DESC"] order_sql = order_by if order_by in valid_order else "created_at DESC" # 第一步:查询父评论(parent_id IS NULL) parent_result = await db.execute( text(f""" SELECT c.id, c.content, c.created_at, u.id AS commenter_id, u.account AS commenter_account FROM comments c JOIN users u ON c.commenter_id = u.id WHERE c.share_id = :share_id AND c.parent_id IS NULL ORDER BY {order_sql} LIMIT :offset, :limit """), { "share_id": share_id, "offset": (page - 1) * size, "limit": size } ) parent_comments = [dict(row._mapping) for row in parent_result.fetchall()] parent_ids = [comm["id"] for comm in parent_comments] # 第二步:查询所有子评论(parent_id 在父评论ID列表中) child_comments = [] if parent_ids: child_result = await db.execute( text(f""" SELECT c.id, c.content, c.parent_id, c.created_at, u.id AS commenter_id, u.account AS commenter_account FROM comments c JOIN users u ON c.commenter_id = u.id WHERE c.share_id = :share_id AND c.parent_id IN :parent_ids ORDER BY {order_sql} """), { "share_id": share_id, "parent_ids": tuple(parent_ids) } ) child_comments = [dict(row._mapping) for row in child_result.fetchall()] # 第三步:构建父子评论层级 child_map = {} for child in child_comments: parent_id = child["parent_id"] if parent_id not in child_map: child_map[parent_id] = [] child_map[parent_id].append(child) # 给父评论添加子评论列表 for comm in parent_comments: comm["children"] = child_map.get(comm["id"], []) return total, parent_comments async def get_comment_by_id(comment_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过ID查询评论(校验父评论是否存在)""" result = await db.execute( text(""" SELECT id, share_id, commenter_id, parent_id FROM comments WHERE id = :comment_id """), {"comment_id": comment_id} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 点赞操作 -------------------------- async def check_share_like(share_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> 
bool: """校验用户是否已点赞该分享""" result = await db.execute( text(""" SELECT 1 FROM share_likes WHERE share_id = :share_id AND user_id = :user_id """), {"share_id": share_id, "user_id": user_id} ) return result.scalar() is not None async def add_share_like(share_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> None: """给分享点赞""" await db.execute( text(""" INSERT IGNORE INTO share_likes (share_id, user_id, liked_at) VALUES (:share_id, :user_id, NOW()) """), {"share_id": share_id, "user_id": user_id} ) async def remove_share_like(share_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> None: """取消分享点赞""" await db.execute( text(""" DELETE FROM share_likes WHERE share_id = :share_id AND user_id = :user_id """), {"share_id": share_id, "user_id": user_id} ) # -------------------------- 搜索记录操作 -------------------------- async def add_search_record( keyword: str, user_id: Optional[str] = None, search_time: Optional[datetime] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> None: """记录用户搜索行为(用于热搜统计)""" search_time = search_time or datetime.now() await db.execute( text(""" INSERT INTO search_records (keyword, user_id, search_time) VALUES (:keyword, :user_id, :search_time) """), { "keyword": keyword.strip(), "user_id": user_id, "search_time": search_time } ) async def get_hot_searches( date: Optional[date] = None, limit: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> List[Dict]: """获取热搜词TOP(默认今日,按搜索次数排序)""" date = date or datetime.now().date() result = await db.execute( text(""" SELECT keyword, COUNT(*) AS search_count FROM search_records WHERE DATE(search_time) = :date GROUP BY keyword ORDER BY search_count DESC LIMIT :limit """), {"date": date, "limit": limit} ) return [dict(row._mapping) for row in result.fetchall()] async def search_shares( keyword: str, is_public: bool = True, author_id: Optional[str] = None, page: int = 1, size: int = 10, db: AsyncSession = 
Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """搜索分享(关键词匹配标题/内容)""" # 构建模糊查询参数 like_keyword = f"%{keyword}%" params = { "keyword": like_keyword, "is_public": is_public, "offset": (page - 1) * size, "limit": size } if author_id: params["author_id"] = author_id author_clause = "AND s.author_id = :author_id" else: author_clause = "" # 1. 统计总数 total_result = await db.execute( text(f""" SELECT COUNT(*) AS total FROM shares s WHERE s.is_public = :is_public AND (s.title LIKE :keyword OR s.content LIKE :keyword) {author_clause} """), params ) total = total_result.scalar() # 2. 查询分页数据 data_result = await db.execute( text(f""" SELECT s.*, u.account AS author_account, c.name AS ai_char_name FROM shares s JOIN users u ON s.author_id = u.id LEFT JOIN characters c ON s.ai_character_id = c.id WHERE s.is_public = :is_public AND (s.title LIKE :keyword OR s.content LIKE :keyword) {author_clause} ORDER BY s.created_at DESC LIMIT :offset, :limit """), params ) shares = [dict(row._mapping) for row in data_result.fetchall()] return total, shares async def search_rooms( keyword: str, user_id: str, department_id: int, is_admin: bool = False, page: int = 1, size: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: like_keyword = f"%{keyword}%" params = { "keyword": like_keyword, "user_id": user_id, "dept_id": department_id, "offset": (page - 1) * size, "limit": size } where_clauses = [] if not is_admin: where_clauses.append("(r.type = 'public' OR r.dept_id = :dept_id)") where_clauses.append(""" EXISTS ( SELECT 1 FROM room_members rm WHERE rm.room_id = r.id AND rm.user_id = :user_id ) """) where_sql = " AND ".join(where_clauses) if where_sql: where_sql = "AND " + where_sql total_result = await db.execute(text(f""" SELECT COUNT(*) AS total FROM rooms r WHERE (r.name LIKE :keyword OR r.description LIKE :keyword) {where_sql} """), params) total = total_result.scalar() data_result = await db.execute(text(f""" SELECT r.id, r.name, r.type, 
r.description, r.creator_id, u.account AS creator_account, r.created_at, COUNT(rm.user_id) AS member_count FROM rooms r JOIN users u ON r.creator_id = u.id LEFT JOIN room_members rm ON r.id = rm.room_id WHERE (r.name LIKE :keyword OR r.description LIKE :keyword) {where_sql} GROUP BY r.id ORDER BY r.created_at DESC LIMIT :offset, :limit """), params) rooms = [dict(row._mapping) for row in data_result.fetchall()] return total, rooms async def search_users( keyword: str, current_user_id: str, is_admin: bool = False, page: int = 1, size: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: like_keyword = f"%{keyword}%" params = { "keyword": like_keyword, "current_user_id": current_user_id, "offset": (page - 1) * size, "limit": size } select_fields = "u.id, u.account, u.role, u.department_id, u.created_at, d.name AS dept_name" from_join = "FROM users u LEFT JOIN departments d ON u.department_id = d.id" where_clause = "(u.account LIKE :keyword)" if not is_admin: where_clause += " AND u.id != :current_user_id" total_result = await db.execute(text(f""" SELECT COUNT(*) AS total {from_join} WHERE {where_clause} """), params) total = total_result.scalar() data_result = await db.execute(text(f""" SELECT {select_fields} {from_join} WHERE {where_clause} ORDER BY u.created_at DESC LIMIT :offset, :limit """), params) users = [dict(row._mapping) for row in data_result.fetchall()] return total, users async def get_hot_search_keywords( start_time: datetime, limit: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> List[Dict]: result = await db.execute(text(""" SELECT keyword, COUNT(*) AS search_count FROM search_records WHERE search_time >= :start_time GROUP BY keyword ORDER BY search_count DESC LIMIT :limit """), {"start_time": start_time, "limit": limit}) return [dict(row._mapping) for row in result.fetchall()] async def get_user_search_history( user_id: str, limit: int = 10, days: int = 30, db: AsyncSession = 
Depends(get_default_db_instance) ) -> List[Dict]: cutoff_time = datetime.now(timezone.utc) - timedelta(days=days) result = await db.execute(text(""" SELECT id, keyword, search_time FROM search_records WHERE user_id = :user_id AND search_time >= :cutoff_time ORDER BY search_time DESC LIMIT :limit """), {"user_id": user_id, "cutoff_time": cutoff_time, "limit": limit}) return [dict(row._mapping) for row in result.fetchall()] async def get_search_history_by_id( history_id: int, db: AsyncSession = Depends(get_default_db_instance) ) -> Optional[Dict]: result = await db.execute(text(""" SELECT id, user_id, keyword, search_time FROM search_records WHERE id = :history_id """), {"history_id": history_id}) row = result.fetchone() return dict(row._mapping) if row else None async def delete_search_history( history_id: int, db: AsyncSession = Depends(get_default_db_instance) ) -> None: await db.execute(text("DELETE FROM search_records WHERE id = :history_id"), {"history_id": history_id}) async def clear_user_search_history( user_id: str, db: AsyncSession = Depends(get_default_db_instance) ) -> None: await db.execute(text("DELETE FROM search_records WHERE user_id = :user_id"), {"user_id": user_id}) async def recommend_shares_by_keywords( keywords: List[str], limit: int = 5, exclude_user_id: Optional[str] = None, department_id: Optional[int] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> List[Dict]: conditions = " OR ".join([f"s.title LIKE :k{i} OR s.content LIKE :k{i}" for i in range(len(keywords))]) params = {f"k{i}": f"%{kw}%" for i, kw in enumerate(keywords)} params["limit"] = limit if exclude_user_id: params["exclude_user_id"] = exclude_user_id filters = " AND s.author_id != :exclude_user_id" if exclude_user_id else "" order_by = "CASE" for i, kw in enumerate(keywords): order_by += f" WHEN s.title LIKE '%{kw}%' THEN {i}" order_by += f" WHEN s.content LIKE '%{kw}%' THEN {i + len(keywords)}" order_by += " ELSE 99 END" result = await db.execute(text(f""" SELECT 
s.*, u.account AS author_account, c.name AS ai_char_name FROM shares s JOIN users u ON s.author_id = u.id LEFT JOIN characters c ON s.ai_character_id = c.id WHERE ({conditions}) AND s.is_public = TRUE {filters} ORDER BY {order_by}, s.like_count DESC LIMIT :limit """), params) return [dict(row._mapping) for row in result.fetchall()] async def get_dept_hot_shares( dept_id: int, limit: int = 5, exclude_user_id: Optional[str] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> List[Dict]: params = {"dept_id": dept_id, "limit": limit} exclude_clause = " AND s.author_id != :exclude_user_id" if exclude_user_id else "" if exclude_user_id: params["exclude_user_id"] = exclude_user_id result = await db.execute(text(f""" SELECT s.*, u.account AS author_account FROM shares s JOIN users u ON s.author_id = u.id WHERE u.department_id = :dept_id AND s.is_public = TRUE {exclude_clause} ORDER BY s.like_count DESC, s.view_count DESC LIMIT :limit """), params) return [dict(row._mapping) for row in result.fetchall()] async def get_global_hot_shares( limit: int = 5, exclude_user_id: Optional[str] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> List[Dict]: params = {"limit": limit} exclude_clause = " AND s.author_id != :exclude_user_id" if exclude_user_id else "" if exclude_user_id: params["exclude_user_id"] = exclude_user_id result = await db.execute(text(f""" SELECT s.*, u.account AS author_account FROM shares s JOIN users u ON s.author_id = u.id WHERE s.is_public = TRUE {exclude_clause} ORDER BY s.like_count DESC, s.view_count DESC LIMIT :limit """), params) return [dict(row._mapping) for row in result.fetchall()] # -------------------------- 管理员统计操作 -------------------------- async def get_user_stats( start_date: date, end_date: date, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """用户统计(总数、新增数、角色分布)""" # 1. 总用户数 total_result = await db.execute(text("SELECT COUNT(*) AS total FROM users")) total = total_result.scalar() # 2. 
时间范围内新增用户数 new_result = await db.execute( text(""" SELECT COUNT(*) AS new_count FROM users WHERE DATE(created_at) BETWEEN :start AND :end """), {"start": start_date, "end": end_date} ) new_count = new_result.scalar() # 3. 角色分布 role_result = await db.execute( text(""" SELECT role, COUNT(*) AS count FROM users GROUP BY role """) ) role_dist = [dict(row._mapping) for row in role_result.fetchall()] # 4. 院系分布(前10) dept_result = await db.execute( text(""" SELECT d.name AS dept_name, COUNT(u.id) AS user_count FROM departments d LEFT JOIN users u ON d.id = u.department_id GROUP BY d.id ORDER BY user_count DESC LIMIT 10 """) ) dept_dist = [dict(row._mapping) for row in dept_result.fetchall()] return { "total_user": total, "new_user": new_count, "role_distribution": role_dist, "department_distribution": dept_dist } async def get_share_stats( start_date: date, end_date: date, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """分享统计(总数、新增数、类型分布、互动统计)""" # 1. 总分享数 total_result = await db.execute(text("SELECT COUNT(*) AS total FROM shares")) total = total_result.scalar() # 2. 时间范围内新增分享数 new_result = await db.execute( text(""" SELECT COUNT(*) AS new_count FROM shares WHERE DATE(created_at) BETWEEN :start AND :end """), {"start": start_date, "end": end_date} ) new_count = new_result.scalar() # 3. 分享类型分布 type_result = await db.execute( text(""" SELECT type, COUNT(*) AS count FROM shares GROUP BY type """) ) type_dist = [dict(row._mapping) for row in type_result.fetchall()] # 4. AI角色关联分布(前10) ai_result = await db.execute( text(""" SELECT c.name AS ai_char_name, COUNT(s.id) AS share_count FROM characters c LEFT JOIN shares s ON c.id = s.ai_character_id WHERE s.ai_character_id IS NOT NULL GROUP BY c.id ORDER BY share_count DESC LIMIT 10 """) ) ai_dist = [dict(row._mapping) for row in ai_result.fetchall()] # 5. 
总互动数(点赞+评论) interact_result = await db.execute( text(""" SELECT SUM(like_count) AS total_like, SUM(comment_count) AS total_comment FROM shares """) ) interact = dict(interact_result.fetchone()._mapping) return { "total_share": total, "new_share": new_count, "type_distribution": type_dist, "ai_character_distribution": ai_dist, "total_interaction": interact } 采用内部直接共用全局db,并修复漏洞,给我完整代码
11-08
from pathlib import Path from fastapi import APIRouter, Depends, Query, HTTPException, Request from fastapi.responses import JSONResponse from datetime import datetime, timezone import logging from backend.jwt_handler import TokenData, get_current_user_token_data from backend import database router = APIRouter(prefix="/api/search", tags=["搜索推荐"]) logger = logging.getLogger(__name__) @router.get("") async def search( request: Request, keyword: str = Query(..., min_length=1, max_length=100), type: str = Query("all", pattern="^(all|share|room|user)$"), page: int = Query(1, ge=1), size: int = Query(10, ge=1, le=100), current_user: TokenData = Depends(get_current_user_token_data) ): """搜索功能(支持分享/聊天室/用户多类型搜索)""" client_ip = request.client.host logger.info( f"🔍 用户发起搜索 | 用户ID:{current_user.user_id} | 关键词:{keyword} | " f"类型:{type} | 分页:第{page}页/每页{size}条 | IP:{client_ip}" ) try: # 记录搜索关键词(带时间戳) await database.add_search_record( keyword=keyword, user_id=current_user.user_id, search_time=datetime.now(timezone.utc) ) # 根据类型搜索不同内容 share_results = (0, []) if type == "all" or type == "share": # 分享搜索:公开+本人私有可见 share_results = await database.search_shares( keyword=keyword, is_public=True, author_id=current_user.user_id, page=page, size=size ) room_results = (0, []) if type == "all" or type == "room": # 聊天室搜索:公开+本人所属院系+已加入的 room_results = await database.search_rooms( keyword=keyword, user_id=current_user.user_id, department_id=current_user.department_id, is_admin=current_user.role == "admin", page=page, size=size ) user_results = (0, []) if type == "all" or type == "user": # 用户搜索:支持按账号/姓名匹配,隐藏敏感信息 user_results = await database.search_users( keyword=keyword, current_user_id=current_user.user_id, is_admin=current_user.role == "admin", page=page, size=size ) logger.debug( f"✅ 搜索完成 | 用户ID:{current_user.user_id} | 分享结果:{share_results[0]}条 | " f"聊天室结果:{room_results[0]}条 | 用户结果:{user_results[0]}条" ) return JSONResponse({ "success": True, "data": { "keyword": keyword, "results": { "shares": 
{ "total": share_results[0], "items": share_results[1] }, "rooms": { "total": room_results[0], "items": room_results[1] }, "users": { "total": user_results[0], "items": user_results[1] } } } }) except Exception as e: logger.error( f"💥 搜索失败 | 用户ID:{current_user.user_id} | 关键词:{keyword} | " f"错误:{str(e)}", exc_info=True ) raise HTTPException( status_code=500, detail=f"搜索失败:服务器内部错误({str(e)})" ) @router.get("/hot") async def get_hot_keywords( request: Request, limit: int = Query(10, ge=1, le=50), current_user: TokenData = Depends(get_current_user_token_data) ): """获取热搜词(按近7天搜索量排序)""" client_ip = request.client.host logger.info( f"🔥 用户获取热搜 | 用户ID:{current_user.user_id} | 数量限制:{limit} | IP:{client_ip}" ) try: # 获取近7天的热搜词 days_ago = datetime.now(timezone.utc) - timezone.timedelta(days=7) hot_keywords = await database.get_hot_search_keywords( start_time=days_ago, limit=limit ) # 格式化结果(添加排名) formatted_hot = [ { "rank": idx + 1, "keyword": item["keyword"], "search_count": item["search_count"], "trend": item.get("trend", "stable") # 趋势:up/down/stable } for idx, item in enumerate(hot_keywords) ] logger.debug(f"✅ 获取热搜成功 | 数量:{len(formatted_hot)} | IP:{client_ip}") return JSONResponse({ "success": True, "data": formatted_hot }) except Exception as e: logger.error( f"💥 获取热搜失败 | 用户ID:{current_user.user_id} | 错误:{str(e)}", exc_info=True ) raise HTTPException( status_code=500, detail=f"获取热搜失败:服务器内部错误({str(e)})" ) @router.get("/recommend/shares") async def recommend_shares( request: Request, limit: int = Query(5, ge=1, le=20), current_user: TokenData = Depends(get_current_user_token_data) ): """推荐分享(基于用户搜索历史和院系热门内容)""" client_ip = request.client.host logger.info( f"📊 用户获取分享推荐 | 用户ID:{current_user.user_id} | 数量限制:{limit} | IP:{client_ip}" ) try: # 多维度推荐策略 recommendations = [] # 1. 
基于用户搜索历史推荐 user_history = await database.get_user_search_history( user_id=current_user.user_id, limit=5, days=14 # 只取近14天历史 ) history_recommendations = [] if user_history: history_recommendations = await database.recommend_shares_by_keywords( keywords=[h["keyword"] for h in user_history], limit=limit // 2, # 分配一半额度 exclude_user_id=current_user.user_id, department_id=current_user.department_id ) recommendations.extend(history_recommendations) # 2. 补充院系热门内容(当历史推荐不足时) remaining = limit - len(recommendations) if remaining > 0: dept_hot = await database.get_dept_hot_shares( dept_id=current_user.department_id, limit=remaining, exclude_user_id=current_user.user_id ) recommendations.extend(dept_hot) # 3. 最终补充全局热门(仍不足时) remaining = limit - len(recommendations) if remaining > 0: global_hot = await database.get_global_hot_shares( limit=remaining, exclude_user_id=current_user.user_id ) recommendations.extend(global_hot) # 去重(避免多渠道推荐重复内容) seen_ids = set() unique_recommendations = [] for rec in recommendations: if rec["id"] not in seen_ids: seen_ids.add(rec["id"]) unique_recommendations.append(rec) if len(unique_recommendations) >= limit: break logger.debug( f"✅ 分享推荐完成 | 用户ID:{current_user.user_id} | 推荐数量:{len(unique_recommendations)} | " f"历史推荐:{len(history_recommendations)} | IP:{client_ip}" ) return JSONResponse({ "success": True, "data": unique_recommendations }) except Exception as e: logger.error( f"💥 分享推荐失败 | 用户ID:{current_user.user_id} | 错误:{str(e)}", exc_info=True ) raise HTTPException( status_code=500, detail=f"获取推荐失败:服务器内部错误({str(e)})" ) @router.get("/history") async def get_search_history( request: Request, limit: int = Query(10, ge=1, le=50), current_user: TokenData = Depends(get_current_user_token_data) ): """获取用户搜索历史""" client_ip = request.client.host logger.info( f"📜 用户获取搜索历史 | 用户ID:{current_user.user_id} | 数量限制:{limit} | IP:{client_ip}" ) try: history = await database.get_user_search_history( user_id=current_user.user_id, limit=limit, days=30 # 保留30天内历史 ) # 
格式化时间 formatted_history = [ { "keyword": item["keyword"], "search_time": item["search_time"].strftime("%Y-%m-%d %H:%M:%S"), "id": item["id"] } for item in history ] logger.debug(f"✅ 获取搜索历史成功 | 数量:{len(formatted_history)} | IP:{client_ip}") return JSONResponse({ "success": True, "data": formatted_history }) except Exception as e: logger.error( f"💥 获取搜索历史失败 | 用户ID:{current_user.user_id} | 错误:{str(e)}", exc_info=True ) raise HTTPException( status_code=500, detail=f"获取搜索历史失败:服务器内部错误({str(e)})" ) @router.delete("/history/{history_id}") async def delete_search_history( request: Request, history_id: int = Path(..., ge=1), current_user: TokenData = Depends(get_current_user_token_data) ): """删除单条搜索历史""" client_ip = request.client.host logger.info( f"🗑️ 用户删除搜索历史 | 用户ID:{current_user.user_id} | 历史ID:{history_id} | IP:{client_ip}" ) try: # 校验所有权 history_item = await database.get_search_history_by_id(history_id) if not history_item: raise HTTPException(status_code=404, detail="搜索历史记录不存在") if history_item["user_id"] != current_user.user_id: raise HTTPException(status_code=403, detail="无权限删除该记录") # 执行删除 await database.delete_search_history(history_id) logger.debug(f"✅ 删除搜索历史成功 | 历史ID:{history_id} | IP:{client_ip}") return JSONResponse({ "success": True, "message": "搜索历史已删除" }) except HTTPException: raise except Exception as e: logger.error( f"💥 删除搜索历史失败 | 用户ID:{current_user.user_id} | 历史ID:{history_id} | " f"错误:{str(e)}", exc_info=True ) raise HTTPException( status_code=500, detail=f"删除搜索历史失败:服务器内部错误({str(e)})" ) @router.delete("/history") async def clear_search_history( request: Request, current_user: TokenData = Depends(get_current_user_token_data) ): """清空用户所有搜索历史""" client_ip = request.client.host logger.info( f"🗑️ 用户清空搜索历史 | 用户ID:{current_user.user_id} | IP:{client_ip}" ) try: await database.clear_user_search_history(current_user.user_id) logger.debug(f"✅ 清空搜索历史成功 | 用户ID:{current_user.user_id} | IP:{client_ip}") return JSONResponse({ "success": True, "message": "所有搜索历史已清空" 
}) except Exception as e: logger.error( f"💥 清空搜索历史失败 | 用户ID:{current_user.user_id} | 错误:{str(e)}", exc_info=True ) raise HTTPException( status_code=500, detail=f"清空搜索历史失败:服务器内部错误({str(e)})" ) search接口有多个数据库操作与现有的数据库函数不匹配,帮我找出来并分析哪些可以用现有的函数替代,给我完整代码 import re from typing import Optional, Dict, List, Tuple, AsyncGenerator from datetime import date, datetime from fastapi import Depends from fastapi.concurrency import asynccontextmanager from sqlalchemy.ext.asyncio import create_async_engine,AsyncSession from sqlalchemy.orm import sessionmaker from sqlalchemy import text # -------------------------- 数据库核心配置 -------------------------- # 建议从 .env 加载(原配置硬编码,此处优化为环境变量读取) import os from dotenv import load_dotenv load_dotenv() DATABASE_URL = os.getenv( "DATABASE_URL", "mysql+asyncmy://root:123456@localhost/ai_roleplay?charset=utf8mb4" # 默认 fallback ) # 异步引擎配置(优化连接池参数) engine = create_async_engine( DATABASE_URL, echo=False, # 生产环境设为 False,避免日志冗余 pool_pre_ping=True, # 连接前校验,防止失效连接 pool_size=10, # 常驻连接数 max_overflow=20, # 最大临时连接数 pool_recycle=3600 # 连接超时回收(1小时) ) # 异步 Session 工厂(线程安全) AsyncSessionLocal = sessionmaker( bind=engine, class_=AsyncSession, expire_on_commit=False, # 提交后不失效对象 autoflush=False # 关闭自动刷新,减少不必要 SQL ) # -------------------------- 通用工具函数 -------------------------- def get_conversation_table_name(user_id: str) -> str: """生成用户专属对话表名(防 SQL 注入)""" safe_id = "".join(c for c in str(user_id) if c.isalnum() or c == "_") return f"conversations_{safe_id}" def is_valid_table_name(table_name: str) -> bool: """校验表名合法性(仅允许 conversations_xxx 格式)""" return re.match(r'^conversations_[a-zA-Z0-9_]+$', table_name) is not None @asynccontextmanager async def get_default_db() -> AsyncGenerator[AsyncSession,None]: """ 自动创建默认数据库会话(上下文管理器,自动管理生命周期) 用于数据库函数的默认db参数,避免手动传参 """ async with AsyncSessionLocal() as db: # 基于全局工厂创建独立会话 try: yield db # 提供会话给函数使用 await db.commit() # 函数无异常则提交 except Exception as e: await db.rollback() # 异常则回滚 raise e # 重新抛出异常,让上层处理 finally: await db.close() # 
无论成败都关闭会话 async def get_default_db_instance() -> AsyncSession: """ 获取默认db实例(供函数默认参数使用) 本质是触发 get_default_db() 上下文管理器,返回会话对象 """ return await anext(get_default_db()) # anext() 用于异步上下文管理器 # -------------------------- 用户基础操作 -------------------------- async def get_user_by_account(account: str, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过账号查询用户(用于注册时判重、登录校验)""" result = await db.execute( text(""" SELECT id, account, password, role, department_id, created_at FROM users WHERE account = :account """), {"account": account} ) row = result.fetchone() return dict(row._mapping) if row else None async def get_user_by_id(user_id: str,db:AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过用户ID查询用户(用于角色修改、权限校验)""" result = await db.execute( text(""" SELECT id, account, role, department_id, created_at FROM users WHERE id = :user_id """), {"user_id": user_id} ) row = result.fetchone() return dict(row._mapping) if row else None async def get_user_detail(user_id: str,db:AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """获取用户详情(含院系名称,用于个人中心)""" result = await db.execute( text(""" SELECT u.id, u.account, u.role, u.department_id, u.created_at, d.name AS dept_name FROM users u LEFT JOIN departments d ON u.department_id = d.id WHERE u.id = :user_id """), {"user_id": user_id} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 用户创建与更新 -------------------------- async def create_user( account: str, password: str, role: str = "user", department_id: Optional[int] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """创建用户(注册专用),并自动创建专属对话表""" # 1. 插入用户记录 result = await db.execute( text(""" INSERT INTO users (account, password, role, department_id, created_at) VALUES (:account, :password, :role, :dept_id, NOW()) """), { "account": account, "password": password, "role": role, "dept_id": department_id } ) user_id = result.lastrowid # 获取自增ID # 2. 
创建用户专属对话表(关联 AI 角色) table_name = get_conversation_table_name(user_id) if not is_valid_table_name(table_name): raise ValueError(f"Invalid user ID for conversation table: {user_id}") await db.execute(text(f""" CREATE TABLE IF NOT EXISTS `{table_name}` ( id INT AUTO_INCREMENT PRIMARY KEY, character_id INT NOT NULL, # 关联 AI 角色表 user_message TEXT NOT NULL, # 用户消息 ai_message TEXT NOT NULL, # AI 回复 timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (character_id) REFERENCES characters(id) ON DELETE CASCADE ) ENGINE=InnoDB CHARSET=utf8mb4; """)) # 3. 返回创建的用户信息 return await get_user_by_id(db, user_id) async def update_user( user_id: str, update_params: Dict, # 支持更新:password、role、department_id db: AsyncSession = Depends(get_default_db_instance) ) -> None: """更新用户信息(动态拼接 SQL,避免冗余)""" if not update_params: return # 无参数则不执行 # 动态生成更新字段(防注入:仅允许指定字段) allowed_fields = ["password", "role", "department_id"] set_clause = ", ".join([f"{k} = :{k}" for k in update_params if k in allowed_fields]) if not set_clause: return # 补充用户ID参数 params = {**update_params, "user_id": user_id} await db.execute( text(f"UPDATE users SET {set_clause} WHERE id = :user_id"), params ) # -------------------------- 用户列表与统计 -------------------------- async def get_users_list( page: int = 1, size: int = 10, role: Optional[str] = None, dept_id: Optional[int] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """分页查询用户列表(管理员专用,支持角色/院系筛选)""" # 1. 构建筛选条件 where_clause = [] params = {"offset": (page - 1) * size, "limit": size} if role: where_clause.append("role = :role") params["role"] = role if dept_id is not None: where_clause.append("department_id = :dept_id") params["dept_id"] = dept_id where_sql = "WHERE " + " AND ".join(where_clause) if where_clause else "" # 2. 查询总数(用于分页) total_result = await db.execute( text(f"SELECT COUNT(*) AS total FROM users {where_sql}"), params ) total = total_result.scalar() # 3. 
查询分页数据 data_result = await db.execute( text(f""" SELECT id, account, role, department_id, created_at FROM users {where_sql} ORDER BY created_at DESC LIMIT :offset, :limit """), params ) users = [dict(row._mapping) for row in data_result.fetchall()] return total, users async def get_user_count_by_dept(dept_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> int: """统计指定院系的用户数(删除院系前校验)""" result = await db.execute( text("SELECT COUNT(*) FROM users WHERE department_id = :dept_id"), {"dept_id": dept_id} ) return result.scalar() # -------------------------- 原登录校验修正 -------------------------- async def check_users(account: str, password: str, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Tuple[str, str]]: """仅校验用户账号密码(原逻辑拆分,插入用户移至 create_user)""" result = await db.execute( text("SELECT id, password FROM users WHERE account = :account"), {"account": account} ) row = result.fetchone() return (str(row.id), row.password) if row else None # -------------------------- 院系基础操作 -------------------------- async def create_department( name: str, description: Optional[str] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """创建院系(管理员专用)""" result = await db.execute( text(""" INSERT INTO departments (name, description, created_at) VALUES (:name, :desc, NOW()) """), {"name": name, "desc": description} ) dept_id = result.lastrowid return await get_department_by_id(db, dept_id) async def get_department_by_id(dept_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过ID查询院系""" result = await db.execute( text("SELECT id, name, description, created_at FROM departments WHERE id = :dept_id"), {"dept_id": dept_id} ) row = result.fetchone() return dict(row._mapping) if row else None async def get_department_by_name(name: str, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过名称查询院系(创建时判重)""" result = await db.execute( text("SELECT id, name FROM departments WHERE name = :name"), 
{"name": name} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 院系列表与统计 -------------------------- async def get_departments_with_user_count(db: AsyncSession = Depends(get_default_db_instance)) -> List[Dict]: """获取所有院系(含用户数统计)""" result = await db.execute( text(""" SELECT d.id, d.name, d.description, d.created_at, COUNT(u.id) AS user_count FROM departments d LEFT JOIN users u ON d.id = u.department_id GROUP BY d.id ORDER BY d.created_at DESC """) ) return [dict(row._mapping) for row in result.fetchall()] # -------------------------- 院系更新与删除 -------------------------- async def update_department( dept_id: int, update_params: Dict, # 支持更新:name、description db: AsyncSession = Depends(get_default_db_instance) ) -> None: """更新院系信息(管理员专用)""" allowed_fields = ["name", "description"] set_clause = ", ".join([f"{k} = :{k}" for k in update_params if k in allowed_fields]) if not set_clause: return params = {**update_params, "dept_id": dept_id} await db.execute( text(f"UPDATE departments SET {set_clause} WHERE id = :dept_id"), params ) async def delete_department(dept_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> None: """删除院系(需先确保无用户关联)""" await db.execute( text("DELETE FROM departments WHERE id = :dept_id"), {"dept_id": dept_id} ) # -------------------------- 院系专属资源查询 -------------------------- async def get_dept_exclusive_rooms(dept_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> List[Dict]: """获取院系专属聊天室(仅本院系用户可见)""" result = await db.execute( text(""" SELECT id, name, description, creator_id, created_at FROM rooms WHERE type = 'dept' AND dept_id = :dept_id ORDER BY created_at DESC """), {"dept_id": dept_id} ) return [dict(row._mapping) for row in result.fetchall()] async def get_dept_exclusive_shares( dept_id: int, page: int = 1, size: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """获取院系专属分享(分页,仅本院系用户可见)""" # 1. 
统计总数 total_result = await db.execute( text(""" SELECT COUNT(*) AS total FROM shares s JOIN users u ON s.author_id = u.id WHERE s.type = 'dept' AND u.department_id = :dept_id """), {"dept_id": dept_id} ) total = total_result.scalar() # 2. 查询分页数据 data_result = await db.execute( text(""" SELECT s.*, u.account AS author_account FROM shares s JOIN users u ON s.author_id = u.id WHERE s.type = 'dept' AND u.department_id = :dept_id ORDER BY s.created_at DESC LIMIT :offset, :limit """), { "dept_id": dept_id, "offset": (page - 1) * size, "limit": size } ) shares = [dict(row._mapping) for row in data_result.fetchall()] return total, shares # -------------------------- AI角色操作 -------------------------- async def get_all_characters(db: AsyncSession = Depends(get_default_db_instance)) -> List[Dict]: """获取所有AI角色(用于聊天页面角色选择)""" result = await db.execute( text("SELECT id, name, trait, avatar_url FROM characters ORDER BY name ASC") ) return [dict(row._mapping) for row in result.fetchall()] async def get_character_by_id(character_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过ID查询AI角色(聊天时获取角色设定)""" result = await db.execute( text("SELECT id, name, trait FROM characters WHERE id = :character_id"), {"character_id": character_id} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 对话历史操作 -------------------------- async def save_conversation( user_id: int, character_id: int, user_message: str, ai_message: str, db: AsyncSession = Depends(get_default_db_instance) ) -> None: """保存用户与AI的对话(聊天后存储)""" table_name = get_conversation_table_name(user_id) if not is_valid_table_name(table_name): raise ValueError(f"Invalid user ID: {user_id}") await db.execute( text(f""" INSERT INTO `{table_name}` (character_id, user_message, ai_message) VALUES (:char_id, :user_msg, :ai_msg) """), { "char_id": character_id, "user_msg": user_message, "ai_msg": ai_message } ) async def load_conversation_history( user_id: str, 
character_id: Optional[int] = None, max_count: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> List[Dict]: """加载用户对话历史(支持按AI角色筛选)""" table_name = get_conversation_table_name(user_id) if not is_valid_table_name(table_name): return [] # 表名无效则返回空历史 # 构建筛选条件(可选按角色筛选) where_clause = "WHERE character_id = :char_id" if character_id else "" params = {"limit": max_count} if character_id: params["char_id"] = character_id # 查询最近的 max_count 条历史(时间正序) result = await db.execute( text(f""" SELECT user_message, ai_message, timestamp FROM `{table_name}` {where_clause} ORDER BY timestamp DESC LIMIT :limit """), params ) rows = result.fetchall() # 转换为 [{user: ..., ai: ...}, ...] 格式,按时间正序排列 history = [ { "user": row.user_message, "ai": row.ai_message, "time": row.timestamp.strftime("%Y-%m-%d %H:%M:%S") } for row in reversed(rows) # 反转后变为时间正序 ] return history # -------------------------- 用户个性化设定 -------------------------- async def get_user_profile(user_id: str,db:AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """获取用户个性化设定(如自定义角色配置)""" result = await db.execute( text(""" SELECT personality, role_setting FROM user_profiles WHERE user_id = :user_id """), {"user_id": user_id} ) row = result.fetchone() return dict(row._mapping) if row else None async def create_or_update_user_profile( user_id: str, personality: str, role_setting: str, db: AsyncSession = Depends(get_default_db_instance) ) -> bool: """创建/更新用户个性化设定(存在则更新,不存在则创建)""" await db.execute( text(""" INSERT INTO user_profiles (user_id, personality, role_setting, updated_at) VALUES (:user_id, :personality, :role_setting, NOW()) ON DUPLICATE KEY UPDATE personality = VALUES(personality), role_setting = VALUES(role_setting), updated_at = NOW() """), { "user_id": user_id, "personality": personality.strip(), "role_setting": role_setting.strip() } ) return True # -------------------------- 聊天室基础操作 -------------------------- async def create_room( name: str, type: str, # 类型:public(公开)、dept(院系)、ai(AI专属) 
creator_id: str, dept_id: Optional[int] = None, ai_character_id: Optional[int] = None, description: Optional[str] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """创建聊天室(支持三种类型)""" result = await db.execute( text(""" INSERT INTO rooms ( name, type, dept_id, ai_character_id, description, creator_id, created_at ) VALUES (:name, :type, :dept_id, :ai_char_id, :desc, :creator_id, NOW()) """), { "name": name, "type": type, "dept_id": dept_id, "ai_char_id": ai_character_id, "desc": description, "creator_id": creator_id } ) room_id = result.lastrowid return await get_room_by_id(db, room_id) async def get_room_by_id(room_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过ID查询聊天室详情""" result = await db.execute( text(""" SELECT r.*, u.account AS creator_account, c.name AS ai_char_name # 关联AI角色名称(若为AI专属) FROM rooms r JOIN users u ON r.creator_id = u.id LEFT JOIN characters c ON r.ai_character_id = c.id WHERE r.id = :room_id """), {"room_id": room_id} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 聊天室列表查询 -------------------------- async def get_rooms( type: Optional[str] = None, dept_id: Optional[int] = None, page: int = 1, size: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """分页查询聊天室(支持按类型、院系筛选)""" # 1. 构建筛选条件 where_clause = [] params = {"offset": (page - 1) * size, "limit": size} if type: where_clause.append("r.type = :type") params["type"] = type if dept_id is not None: where_clause.append("r.dept_id = :dept_id") params["dept_id"] = dept_id where_sql = "WHERE " + " AND ".join(where_clause) if where_clause else "" # 2. 统计总数 total_result = await db.execute( text(f""" SELECT COUNT(*) AS total FROM rooms r {where_sql} """), params ) total = total_result.scalar() # 3. 
查询分页数据 data_result = await db.execute( text(f""" SELECT r.id, r.name, r.type, r.description, r.creator_id, u.account AS creator_account, r.created_at, COUNT(rm.user_id) AS member_count FROM rooms r JOIN users u ON r.creator_id = u.id LEFT JOIN room_members rm ON r.id = rm.room_id {where_sql} GROUP BY r.id ORDER BY r.created_at DESC LIMIT :offset, :limit """), params ) rooms = [dict(row._mapping) for row in data_result.fetchall()] return total, rooms # -------------------------- 聊天室成员管理 -------------------------- async def check_room_member(room_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> bool: """校验用户是否为聊天室成员""" result = await db.execute( text(""" SELECT 1 FROM room_members WHERE room_id = :room_id AND user_id = :user_id """), {"room_id": room_id, "user_id": user_id} ) return result.scalar() is not None async def add_room_member(room_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> None: """添加用户到聊天室成员(加入聊天室)""" await db.execute( text(""" INSERT IGNORE INTO room_members (room_id, user_id, joined_at) VALUES (:room_id, :user_id, NOW()) """), {"room_id": room_id, "user_id": user_id} ) async def remove_room_member(room_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> None: """从聊天室移除用户(离开聊天室)""" await db.execute( text(""" DELETE FROM room_members WHERE room_id = :room_id AND user_id = :user_id """), {"room_id": room_id, "user_id": user_id} ) # -------------------------- 聊天室消息操作 -------------------------- async def create_room_message( room_id: int, sender_id: str, content: str, sent_at: Optional[datetime] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """发送聊天室消息(成员专用)""" sent_at = sent_at or datetime.now() result = await db.execute( text(""" INSERT INTO room_messages ( room_id, sender_id, content, sent_at ) VALUES (:room_id, :sender_id, :content, :sent_at) """), { "room_id": room_id, "sender_id": sender_id, "content": content, "sent_at": sent_at } ) 
msg_id = result.lastrowid # 返回消息详情(含发送者账号) msg_result = await db.execute( text(""" SELECT rm.id, rm.content, rm.sent_at, u.account AS sender_account FROM room_messages rm JOIN users u ON rm.sender_id = u.id WHERE rm.id = :msg_id """), {"msg_id": msg_id} ) return dict(msg_result.fetchone()._mapping) async def get_room_messages( room_id: int, page: int = 1, size: int = 20, order_by: str = "sent_at DESC", db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """分页获取聊天室历史消息(支持排序)""" # 1. 统计总数 total_result = await db.execute( text("SELECT COUNT(*) AS total FROM room_messages WHERE room_id = :room_id"), {"room_id": room_id} ) total = total_result.scalar() # 2. 查询分页数据(防排序注入:仅允许指定排序字段) valid_order = ["sent_at ASC", "sent_at DESC"] order_sql = order_by if order_by in valid_order else "sent_at DESC" data_result = await db.execute( text(f""" SELECT rm.id, rm.content, rm.sent_at, u.id AS sender_id, u.account AS sender_account FROM room_messages rm JOIN users u ON rm.sender_id = u.id WHERE rm.room_id = :room_id ORDER BY {order_sql} LIMIT :offset, :limit """), { "room_id": room_id, "offset": (page - 1) * size, "limit": size } ) messages = [dict(row._mapping) for row in data_result.fetchall()] return total, messages # -------------------------- 聊天室删除 -------------------------- async def delete_room(room_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> None: """删除聊天室(级联删除成员和消息)""" # 1. 删除成员关联 await db.execute( text("DELETE FROM room_members WHERE room_id = :room_id"), {"room_id": room_id} ) # 2. 删除消息 await db.execute( text("DELETE FROM room_messages WHERE room_id = :room_id"), {"room_id": room_id} ) # 3. 
删除聊天室本身 await db.execute( text("DELETE FROM rooms WHERE id = :room_id"), {"room_id": room_id} ) # -------------------------- 分享基础操作 -------------------------- async def create_share( title: str, content: str, author_id: str, is_public: bool = True, type: str = "public", # 类型:public(公开)、private(私有)、dept(院系) ai_character_id: Optional[int] = None, created_at: Optional[datetime] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """发布分享(支持三种类型)""" created_at = created_at or datetime.now() result = await db.execute( text(""" INSERT INTO shares ( title, content, author_id, is_public, type, ai_character_id, view_count, like_count, comment_count, created_at ) VALUES ( :title, :content, :author_id, :is_public, :type, :ai_char_id, 0, 0, 0, :created_at ) """), { "title": title, "content": content, "author_id": author_id, "is_public": is_public, "type": type, "ai_char_id": ai_character_id, "created_at": created_at } ) share_id = result.lastrowid return await get_share_by_id(db, share_id) async def get_share_by_id(share_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过ID查询分享详情(含作者信息)""" result = await db.execute( text(""" SELECT s.*, u.account AS author_account, u.department_id, c.name AS ai_char_name # 关联AI角色名称 FROM shares s JOIN users u ON s.author_id = u.id LEFT JOIN characters c ON s.ai_character_id = c.id WHERE s.id = :share_id """), {"share_id": share_id} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 分享列表查询 -------------------------- async def get_shares( is_public: Optional[bool] = None, author_id: Optional[str] = None, type: Optional[str] = None, order_by: str = "created_at DESC", page: int = 1, size: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """分页查询分享(支持按公开性、作者、类型筛选)""" # 1. 
构建筛选条件 where_clause = [] params = {"offset": (page - 1) * size, "limit": size} if is_public is not None: where_clause.append("s.is_public = :is_public") params["is_public"] = is_public if author_id: where_clause.append("s.author_id = :author_id") params["author_id"] = author_id if type: where_clause.append("s.type = :type") params["type"] = type where_sql = "WHERE " + " AND ".join(where_clause) if where_clause else "" # 2. 统计总数 total_result = await db.execute( text(f"SELECT COUNT(*) AS total FROM shares s {where_sql}"), params ) total = total_result.scalar() # 3. 查询分页数据(防排序注入) valid_order = ["created_at DESC", "created_at ASC", "like_count DESC", "view_count DESC"] order_sql = order_by if order_by in valid_order else "created_at DESC" data_result = await db.execute( text(f""" SELECT s.*, u.account AS author_account, c.name AS ai_char_name FROM shares s JOIN users u ON s.author_id = u.id LEFT JOIN characters c ON s.ai_character_id = c.id {where_sql} ORDER BY {order_sql} LIMIT :offset, :limit """), params ) shares = [dict(row._mapping) for row in data_result.fetchall()] return total, shares # -------------------------- 分享更新与删除 -------------------------- async def update_share( share_id: int, update_params: Dict, # 支持更新:title、content、is_public、type、ai_character_id、view_count等 db: AsyncSession = Depends(get_default_db_instance) ) -> None: """更新分享信息(作者专用)""" allowed_fields = [ "title", "content", "is_public", "type", "ai_character_id", "view_count", "like_count", "comment_count" ] set_clause = ", ".join([f"{k} = :{k}" for k in update_params if k in allowed_fields]) if not set_clause: return params = {**update_params, "share_id": share_id} await db.execute( text(f"UPDATE shares SET {set_clause} WHERE id = :share_id"), params ) async def delete_share(share_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> None: """删除分享(级联删除评论和点赞)""" # 1. 删除点赞关联 await db.execute( text("DELETE FROM share_likes WHERE share_id = :share_id"), {"share_id": share_id} ) # 2. 
删除评论 await db.execute( text("DELETE FROM comments WHERE share_id = :share_id"), {"share_id": share_id} ) # 3. 删除分享本身 await db.execute( text("DELETE FROM shares WHERE id = :share_id"), {"share_id": share_id} ) # -------------------------- 评论操作 -------------------------- async def create_comment( share_id: int, commenter_id: str, content: str, parent_id: Optional[int] = None, created_at: Optional[datetime] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """发表评论(支持回复父评论)""" created_at = created_at or datetime.now() result = await db.execute( text(""" INSERT INTO comments ( share_id, commenter_id, parent_id, content, created_at ) VALUES (:share_id, :commenter_id, :parent_id, :content, :created_at) """), { "share_id": share_id, "commenter_id": commenter_id, "parent_id": parent_id, "content": content, "created_at": created_at } ) comment_id = result.lastrowid # 返回评论详情(含评论者账号) comm_result = await db.execute( text(""" SELECT c.id, c.content, c.parent_id, c.created_at, u.account AS commenter_account FROM comments c JOIN users u ON c.commenter_id = u.id WHERE c.id = :comment_id """), {"comment_id": comment_id} ) return dict(comm_result.fetchone()._mapping) async def get_share_comments( share_id: int, page: int = 1, size: int = 20, order_by: str = "created_at DESC", db: AsyncSession = Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """分页获取分享的评论(含子评论层级)""" # 1. 统计总数 total_result = await db.execute( text("SELECT COUNT(*) AS total FROM comments WHERE share_id = :share_id"), {"share_id": share_id} ) total = total_result.scalar() # 2. 
查询分页数据(先查父评论,再关联子评论) valid_order = ["created_at ASC", "created_at DESC"] order_sql = order_by if order_by in valid_order else "created_at DESC" # 第一步:查询父评论(parent_id IS NULL) parent_result = await db.execute( text(f""" SELECT c.id, c.content, c.created_at, u.id AS commenter_id, u.account AS commenter_account FROM comments c JOIN users u ON c.commenter_id = u.id WHERE c.share_id = :share_id AND c.parent_id IS NULL ORDER BY {order_sql} LIMIT :offset, :limit """), { "share_id": share_id, "offset": (page - 1) * size, "limit": size } ) parent_comments = [dict(row._mapping) for row in parent_result.fetchall()] parent_ids = [comm["id"] for comm in parent_comments] # 第二步:查询所有子评论(parent_id 在父评论ID列表中) child_comments = [] if parent_ids: child_result = await db.execute( text(f""" SELECT c.id, c.content, c.parent_id, c.created_at, u.id AS commenter_id, u.account AS commenter_account FROM comments c JOIN users u ON c.commenter_id = u.id WHERE c.share_id = :share_id AND c.parent_id IN :parent_ids ORDER BY {order_sql} """), { "share_id": share_id, "parent_ids": tuple(parent_ids) } ) child_comments = [dict(row._mapping) for row in child_result.fetchall()] # 第三步:构建父子评论层级 child_map = {} for child in child_comments: parent_id = child["parent_id"] if parent_id not in child_map: child_map[parent_id] = [] child_map[parent_id].append(child) # 给父评论添加子评论列表 for comm in parent_comments: comm["children"] = child_map.get(comm["id"], []) return total, parent_comments async def get_comment_by_id(comment_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]: """通过ID查询评论(校验父评论是否存在)""" result = await db.execute( text(""" SELECT id, share_id, commenter_id, parent_id FROM comments WHERE id = :comment_id """), {"comment_id": comment_id} ) row = result.fetchone() return dict(row._mapping) if row else None # -------------------------- 点赞操作 -------------------------- async def check_share_like(share_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> 
bool: """校验用户是否已点赞该分享""" result = await db.execute( text(""" SELECT 1 FROM share_likes WHERE share_id = :share_id AND user_id = :user_id """), {"share_id": share_id, "user_id": user_id} ) return result.scalar() is not None async def add_share_like(share_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> None: """给分享点赞""" await db.execute( text(""" INSERT IGNORE INTO share_likes (share_id, user_id, liked_at) VALUES (:share_id, :user_id, NOW()) """), {"share_id": share_id, "user_id": user_id} ) async def remove_share_like(share_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> None: """取消分享点赞""" await db.execute( text(""" DELETE FROM share_likes WHERE share_id = :share_id AND user_id = :user_id """), {"share_id": share_id, "user_id": user_id} ) # -------------------------- 搜索记录操作 -------------------------- async def add_search_record( keyword: str, user_id: Optional[str] = None, search_time: Optional[datetime] = None, db: AsyncSession = Depends(get_default_db_instance) ) -> None: """记录用户搜索行为(用于热搜统计)""" search_time = search_time or datetime.now() await db.execute( text(""" INSERT INTO search_records (keyword, user_id, search_time) VALUES (:keyword, :user_id, :search_time) """), { "keyword": keyword.strip(), "user_id": user_id, "search_time": search_time } ) async def get_hot_searches( date: Optional[date] = None, limit: int = 10, db: AsyncSession = Depends(get_default_db_instance) ) -> List[Dict]: """获取热搜词TOP(默认今日,按搜索次数排序)""" date = date or datetime.now().date() result = await db.execute( text(""" SELECT keyword, COUNT(*) AS search_count FROM search_records WHERE DATE(search_time) = :date GROUP BY keyword ORDER BY search_count DESC LIMIT :limit """), {"date": date, "limit": limit} ) return [dict(row._mapping) for row in result.fetchall()] async def search_shares( keyword: str, is_public: bool = True, author_id: Optional[str] = None, page: int = 1, size: int = 10, db: AsyncSession = 
Depends(get_default_db_instance) ) -> Tuple[int, List[Dict]]: """搜索分享(关键词匹配标题/内容)""" # 构建模糊查询参数 like_keyword = f"%{keyword}%" params = { "keyword": like_keyword, "is_public": is_public, "offset": (page - 1) * size, "limit": size } if author_id: params["author_id"] = author_id author_clause = "AND s.author_id = :author_id" else: author_clause = "" # 1. 统计总数 total_result = await db.execute( text(f""" SELECT COUNT(*) AS total FROM shares s WHERE s.is_public = :is_public AND (s.title LIKE :keyword OR s.content LIKE :keyword) {author_clause} """), params ) total = total_result.scalar() # 2. 查询分页数据 data_result = await db.execute( text(f""" SELECT s.*, u.account AS author_account, c.name AS ai_char_name FROM shares s JOIN users u ON s.author_id = u.id LEFT JOIN characters c ON s.ai_character_id = c.id WHERE s.is_public = :is_public AND (s.title LIKE :keyword OR s.content LIKE :keyword) {author_clause} ORDER BY s.created_at DESC LIMIT :offset, :limit """), params ) shares = [dict(row._mapping) for row in data_result.fetchall()] return total, shares # -------------------------- 管理员统计操作 -------------------------- async def get_user_stats( start_date: date, end_date: date, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """用户统计(总数、新增数、角色分布)""" # 1. 总用户数 total_result = await db.execute(text("SELECT COUNT(*) AS total FROM users")) total = total_result.scalar() # 2. 时间范围内新增用户数 new_result = await db.execute( text(""" SELECT COUNT(*) AS new_count FROM users WHERE DATE(created_at) BETWEEN :start AND :end """), {"start": start_date, "end": end_date} ) new_count = new_result.scalar() # 3. 角色分布 role_result = await db.execute( text(""" SELECT role, COUNT(*) AS count FROM users GROUP BY role """) ) role_dist = [dict(row._mapping) for row in role_result.fetchall()] # 4. 
院系分布(前10) dept_result = await db.execute( text(""" SELECT d.name AS dept_name, COUNT(u.id) AS user_count FROM departments d LEFT JOIN users u ON d.id = u.department_id GROUP BY d.id ORDER BY user_count DESC LIMIT 10 """) ) dept_dist = [dict(row._mapping) for row in dept_result.fetchall()] return { "total_user": total, "new_user": new_count, "role_distribution": role_dist, "department_distribution": dept_dist } async def get_share_stats( start_date: date, end_date: date, db: AsyncSession = Depends(get_default_db_instance) ) -> Dict: """分享统计(总数、新增数、类型分布、互动统计)""" # 1. 总分享数 total_result = await db.execute(text("SELECT COUNT(*) AS total FROM shares")) total = total_result.scalar() # 2. 时间范围内新增分享数 new_result = await db.execute( text(""" SELECT COUNT(*) AS new_count FROM shares WHERE DATE(created_at) BETWEEN :start AND :end """), {"start": start_date, "end": end_date} ) new_count = new_result.scalar() # 3. 分享类型分布 type_result = await db.execute( text(""" SELECT type, COUNT(*) AS count FROM shares GROUP BY type """) ) type_dist = [dict(row._mapping) for row in type_result.fetchall()] # 4. AI角色关联分布(前10) ai_result = await db.execute( text(""" SELECT c.name AS ai_char_name, COUNT(s.id) AS share_count FROM characters c LEFT JOIN shares s ON c.id = s.ai_character_id WHERE s.ai_character_id IS NOT NULL GROUP BY c.id ORDER BY share_count DESC LIMIT 10 """) ) ai_dist = [dict(row._mapping) for row in ai_result.fetchall()] # 5. 总互动数(点赞+评论) interact_result = await db.execute( text(""" SELECT SUM(like_count) AS total_like, SUM(comment_count) AS total_comment FROM shares """) ) interact = dict(interact_result.fetchone()._mapping) return { "total_share": total, "new_share": new_count, "type_distribution": type_dist, "ai_character_distribution": ai_dist, "total_interaction": interact }
11-08
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值