Using Python's print(flush=True) to create a dynamic "Loading..." effect

This post looks at what the flush parameter of Python's print function does, with examples showing how to use it for real-time output, such as a dynamic loading animation or an online chat application, and how it forces an immediate flush when printing to a file.

import time

print("Loading", end="")
for i in range(6):
    print(".", end="")
    time.sleep(0.2)

The intent of the code above is an animated effect: on the same line as "Loading", print one dot every 0.2 seconds, six dots in total.
What actually happens, though, is that nothing appears for 6 × 0.2 seconds, and then "Loading......" is printed all at once.
After searching around and combining several answers, the cause turns out to be output buffering: because none of these print calls end with a newline, the text sits in the stdout buffer and is only written out when the program finishes. The fix is to add flush=True after end='' inside the for loop:

import time

print("Loading", end="")
for i in range(6):
    print(".", end="", flush=True)  # flush=True pushes each dot to the screen immediately
    time.sleep(0.2)

This finally produces the desired effect. (The six dots only animate once instead of looping forever, but the core problem is solved.)
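If you do want the animation to loop forever, a minimal sketch of my own (not part of the original post) is to carriage-return back to the start of the line with "\r" and redraw it each frame:

import time

# Hypothetical endless loading spinner; stop it with Ctrl+C.
while True:
    for n in range(1, 7):
        # "\r" moves the cursor back to the start of the line so the
        # previous frame is overwritten in place.
        print("\rLoading" + "." * n + " " * (6 - n), end="", flush=True)
        time.sleep(0.2)

The trailing spaces pad each frame so that a shorter frame erases the leftover dots of a longer one.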
Let's look at what help(print) says:

print(...)
    print(value, ..., sep=' ', end='\n', file=sys.stdout, flush=False)
    ... (omitted) ...
    flush: whether to forcibly flush the stream.

So print has a flush parameter that defaults to False. What is it for, and how is it used?
Here is one example:
In an online web chat, the page shows messages in real time. The backend is continuously producing data for the client; normally output would only be sent once a response is complete, but instant messaging needs each piece of output pushed the moment it is ready, and that is exactly where flush comes in.
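As a rough illustration (my own sketch, not from the original post), the same buffering issue shows up when streaming a reply word by word to the console:

import time

def stream_reply(text):
    # Simulate a chat backend that produces the reply one word at a time.
    for word in text.split():
        yield word

for word in stream_reply("flush makes each chunk visible as soon as it is produced"):
    # Without flush=True the words can pile up in the buffer
    # and show up all at once instead of one by one.
    print(word, end=" ", flush=True)
    time.sleep(0.3)
print()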

Another example:
print can also write to a file. In the Python 3 interactive interpreter, enter:

f = open("123.txt", "w")
print("123456789", file=f)

After running this, open 123.txt: "123456789" has not been written and the file is empty. The content only reaches the file after f.close(). If we add flush=True, i.e. change the code to:

f = open("123.txt", "w")
print("123456789", file=f, flush=True)

then the content is written to the file without calling f.close().
The flush parameter controls flushing. By default flush=False, i.e. no immediate flush: as in the example above, what we print to f first sits in an in-memory buffer and is only written to 123.txt when the file object is closed. With flush=True, the content is flushed to 123.txt immediately.
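For comparison, here is a small sketch of my own (not from the original post) showing two equivalent ways to get the same guarantee: calling the file object's flush() method, or using a with block, which flushes and closes the file when the block exits:

# Explicit flush on the file object, same effect as print(..., flush=True):
f = open("123.txt", "w")
print("123456789", file=f)
f.flush()   # force the buffered text out to 123.txt
f.close()

# A with block closes (and therefore flushes) the file automatically:
with open("123.txt", "w") as f:
    print("123456789", file=f)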

I am new to Python and wrote this post mainly as a note for myself; corrections are welcome.
