map.getOrDefault(key,replace_words)的用法

本文深入探讨了Java中Map接口的getOrDefault方法,该方法为获取映射值提供了便利,当键不存在时能返回默认值,避免了空指针异常,增强了代码的健壮性和可读性。
/**
 * Demonstrates {@code Map.getOrDefault}: looks up one key that is present
 * and one key that is absent, then prints both results.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
	// Small name -> country lookup table used for the demonstration.
	Map<String, String> countryByName = new HashMap<>();
	countryByName.put("段兿", "澳大利亚");
	countryByName.put("程思莺", "澳大利亚");
	countryByName.put("大佬项", "中国");

	// The key exists, so the stored value is returned and the fallback is ignored.
	String existingKeyResult = countryByName.getOrDefault("段兿", "中国");
	// The key does not exist, so the supplied fallback value is returned.
	String missingKeyResult = countryByName.getOrDefault("江小白", "中国");

	System.out.println(existingKeyResult);
	System.out.println(missingKeyResult);
}

map.getOrDefault(key, replace_words) 相当于 map.get(key) 的升级版:如果 map 中含有对应的 key,则返回该 key 对应的 value(与 get() 的结果相同);如果 map 中不含该 key,则返回默认值 replace_words。注意:如果 key 存在但对应的 value 为 null,返回的仍然是 null,而不是默认值。

输出结果:

澳大利亚
中国

import torch import torch.nn as nn import torch.optim as optim import numpy as np import collections import os import time import matplotlib.pyplot as plt # 设置超参数 start_token = 'G' # 诗歌起始标记 end_token = 'E' # 诗歌结束标记 batch_size = 64 # 训练批量大小 embedding_dim = 128 # 词向量维度 hidden_dim = 256 # LSTM隐藏层维度 learning_rate = 0.001 # 学习率 num_epochs = 50 # 训练轮数 # 设备配置(自动选择GPU或CPU) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # 词嵌入层 class WordEmbedding(nn.Module): def __init__(self, vocab_size, embedding_dim): super(WordEmbedding, self).__init__() self.embedding = nn.Embedding(vocab_size, embedding_dim) nn.init.uniform_(self.embedding.weight, -1.0, 1.0) # 均匀初始化 def forward(self, x): return self.embedding(x) # RNN模型 class RNN_Model(nn.Module): def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2, dropout=0.3): super(RNN_Model, self).__init__() self.embedding = WordEmbedding(vocab_size, embedding_dim) self.lstm = nn.LSTM( input_size=embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0 ) self.fc = nn.Linear(hidden_dim, vocab_size) self.dropout = nn.Dropout(dropout) self.init_weights() def init_weights(self): # 全连接层初始化 nn.init.xavier_uniform_(self.fc.weight) nn.init.zeros_(self.fc.bias) # LSTM权重初始化 for name, param in self.lstm.named_parameters(): if 'weight' in name: nn.init.orthogonal_(param) elif 'bias' in name: nn.init.zeros_(param) # 设置遗忘门偏置为1(有助于缓解梯度消失) n = param.size(0) param.data[n // 4:n // 2].fill_(1.0) def forward(self, x, hidden=None): # x: (batch_size, seq_len) batch_size = x.size(0) # 嵌入层 embeds = self.embedding(x) # (batch_size, seq_len, embedding_dim) # LSTM层 if hidden is None: h0 = torch.zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size).to(device) c0 = torch.zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size).to(device) hidden = (h0, c0) lstm_out, hidden = self.lstm(embeds, hidden) # (batch_size, seq_len, 
hidden_dim) # 应用dropout,防止过拟合 lstm_out = self.dropout(lstm_out) # 全连接层 output = self.fc(lstm_out) # (batch_size, seq_len, vocab_size) # 重新排列维度用于损失计算 output = output.permute(0, 2, 1) # (batch_size, vocab_size, seq_len) return output, hidden # 数据处理函数 def process_poems2(file_name): poems = [] with open(file_name, "r", encoding='utf-8') as f: for line in f.readlines(): try: line = line.strip() if line: content = line.replace(' '' ', '').replace(',', '').replace('。', '') if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \ start_token in content or end_token in content: continue if len(content) < 5 or len(content) > 80: continue content = start_token + content + end_token poems.append(content) except ValueError as e: pass # 按诗的字数排序 poems = sorted(poems, key=lambda line: len(line)) # 统计每个字出现次数 all_words = [] for poem in poems: all_words += [word for word in poem] counter = collections.Counter(all_words) # 统计词和词频。 count_pairs = sorted(counter.items(), key=lambda x: -x[1]) # 排序 words, _ = zip(*count_pairs) words = words[:len(words)] + (' ',) word_int_map = dict(zip(words, range(len(words)))) poems_vector = [list(map(word_int_map.get, poem)) for poem in poems] return poems_vector, word_int_map, words # 批量生成函数 def create_batches(poems_idx, batch_size, max_seq_len=50): # 按长度排序,有助于减少填充 sorted_poems = sorted(poems_idx, key=len) batches = [] num_batches = len(sorted_poems) // batch_size for i in range(num_batches): start_idx = i * batch_size end_idx = (i + 1) * batch_size batch = sorted_poems[start_idx:end_idx] # 找到本批次最大长度 max_len = min(max(len(poem) for poem in batch), max_seq_len) # 填充序列 padded_batch = [] for poem in batch: if len(poem) > max_len: padded_poem = poem[:max_len] # 截断 else: padded_poem = poem + [0] * (max_len - len(poem)) # 填充,0是<PAD> padded_batch.append(padded_poem) # 输入和输出序列 inputs = [poem[:-1] for poem in padded_batch] # 输入序列(去除最后一个字符) targets = [poem[1:] for poem in padded_batch] # 输出序列(移除第一个字符) batches.append((inputs, 
targets)) return batches # 训练函数 def train_model(): # 处理数据 poems_idx, char_to_idx, idx_to_char = process_poems2('./tangshi.txt') vocab_size = len(char_to_idx) print(f"Vocabulary size: {vocab_size}") print(f"Number of poems: {len(poems_idx)}") # 创建批次 batches = create_batches(poems_idx, batch_size) print(f"Number of batches: {len(batches)}") # 初始化模型 model = RNN_Model( vocab_size=vocab_size, embedding_dim=embedding_dim, hidden_dim=hidden_dim, num_layers=2 ).to(device) # 模型加载 pretrained_model_path = './models/best_poetry_model' best_model_path = './models/best_tangshi_model.pth' # 确保模型目录存在 os.makedirs('./models', exist_ok=True) # 检查预训练模型是否能加载 model_loaded = False if os.path.exists(pretrained_model_path): try: # 尝试加载预训练模型 model.load_state_dict(torch.load(pretrained_model_path)) print(f"Loaded pretrained model from: {pretrained_model_path}") model_loaded = True except Exception as e: print(f"Error loading pretrained model: {e}") if not model_loaded: print("No valid pretrained model found. Starting training from scratch.") # 损失函数和优化器 criterion = nn.CrossEntropyLoss(ignore_index=0) # 忽略填充索引 optimizer = optim.Adam(model.parameters(), lr=learning_rate) scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='min', factor=0.5, patience=3, verbose=True ) # 训练记录 train_losses = [] best_loss = float('inf') print("Starting training...") start_time = time.time() for epoch in range(num_epochs): model.train() epoch_loss = 0.0 for batch_idx, (inputs, targets) in enumerate(batches): # 转换为张量 inputs_tensor = torch.tensor(inputs, dtype=torch.long).to(device) targets_tensor = torch.tensor(targets, dtype=torch.long).to(device) # 前向传播 optimizer.zero_grad() output, _ = model(inputs_tensor) # 计算损失 loss = criterion(output, targets_tensor) # 反向传播 loss.backward() # 梯度裁剪(防止梯度爆炸) nn.utils.clip_grad_norm_(model.parameters(), 1.0) # 更新参数 optimizer.step() # 记录损失 epoch_loss += loss.item() # 打印进度 if batch_idx % 20 == 0: avg_loss = epoch_loss / (batch_idx + 1) elapsed = time.time() - start_time 
print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{len(batches)}], " f"Loss: {avg_loss:.4f}, Time: {elapsed:.2f}s") # 计算本轮平均损失 epoch_loss /= len(batches) train_losses.append(epoch_loss) scheduler.step(epoch_loss) # 打印摘要 print(f"Epoch [{epoch + 1}/{num_epochs}] completed, Avg Loss: {epoch_loss:.4f}") # 保存最佳模型 if epoch_loss < best_loss: best_loss = epoch_loss torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': epoch_loss, 'char_to_idx': char_to_idx, 'idx_to_char': idx_to_char }, best_model_path) print(f"Saved best model with loss: {best_loss:.4f} at {best_model_path}") # 绘制损失曲线 plt.figure(figsize=(10, 5)) plt.plot(train_losses, label='Training Loss') plt.title('Training Loss Over Epochs') plt.xlabel('Epochs') plt.ylabel('Loss') plt.legend() plt.savefig('./models/tangshi_training_loss.png') plt.close() print("Training completed!") # 生成诗歌函数 def generate_poem(model, idx_to_char, char_to_idx, start_word, max_length=40): model.eval() poem = [] # 初始化输入 input_seq = torch.tensor([[char_to_idx[start_token], char_to_idx[start_word]]], dtype=torch.long).to(device) poem.extend([start_token, start_word]) # 初始化隐藏状态 hidden = None with torch.no_grad(): # 禁用梯度计算 for _ in range(max_length): # 向前传播,预测下一个字符 output, hidden = model(input_seq, hidden) # 获取最后一个字符的预测 last_output = output[:, :, -1] # (batch_size, vocab_size) # 应用温度采样/多项式采样(增加多样性) probabilities = torch.softmax(last_output, dim=1).squeeze() next_idx = torch.multinomial(probabilities, 1).item() # 检查是否结束 if next_idx == char_to_idx[end_token]: break # 添加到诗歌中 next_char = idx_to_char.get(next_idx, '<UNK>') poem.append(next_char) # 更新输入序列 input_seq = torch.tensor([[next_idx]], dtype=torch.long).to(device) return ''.join(poem) # 打印格式化的诗歌 def pretty_print_poem(poem): # 移除标记并分割句子 clean_poem = poem.replace(start_token, '').replace(end_token, '') sentences = clean_poem.split('。') # 打印非空句子 for s in sentences: if s.strip(): print(s.strip() + '。') # 保存诗歌到文件 
def save_poems_to_file(poems, filename="generated_poems.txt"): with open(filename, 'w', encoding='utf-8') as f: f.write("=== 生成古诗 ===\n\n") for i, (word, poem) in enumerate(poems): f.write(f"诗歌 {i + 1} (起始词: {word})\n") clean_poem = poem.replace(start_token, '').replace(end_token, '') # 按句号分行 for char in clean_poem: f.write(char) if char in [',', '。', '!', '?']: f.write('\n') f.write("\n---------------------\n\n") print(f"诗歌已保存到: {filename}") # 主函数 if __name__ == '__main__': # 确保模型目录存在 os.makedirs('./models', exist_ok=True) # 检查是否需要训练 best_model_path = './models/best_tangshi_model.pth' if not os.path.exists(best_model_path): print("训练模型未找到,开始训练...") train_model() else: print("找到已有的训练模型,跳过训练") # 加载最佳模型并生成诗歌 print("\nGenerating poems with best model...") try: # 尝试加载模型 checkpoint = torch.load(best_model_path) print(f"成功加载模型: {best_model_path}") char_to_idx = checkpoint['char_to_idx'] idx_to_char = checkpoint['idx_to_char'] vocab_size = len(char_to_idx) model = RNN_Model( vocab_size=vocab_size, embedding_dim=embedding_dim, hidden_dim=hidden_dim, num_layers=2 ).to(device) model.load_state_dict(checkpoint['model_state_dict']) start_words = ["日", "红", "山", "夜", "湖", "君"] generated_poems = [] for word in start_words: poem = generate_poem(model, idx_to_char, char_to_idx, word) print(f"\n--- Poem starting with '{word}' ---") pretty_print_poem(poem) generated_poems.append((word, poem)) # 保存所有诗歌到文件 save_poems_to_file(generated_poems) except Exception as e: print(f"\n加载模型失败: {e}") print("请检查模型文件是否存在或尝试重新训练模型") 输出 Using device: cuda 找到已有的训练模型,跳过训练 Generating poems with best model... 成功加载模型: ./models/best_tangshi_model.pth 加载模型失败: 'tuple' object has no attribute 'get' 请检查模型文件是否存在或尝试重新训练模型 进程已结束,退出代码为 0 解决上述问题
09-24
import os import sys import cv2 from cv2 import resize import numpy as np import matplotlib.pyplot as plt import argparse from PIL import Image import torch import src.utils as utils import src.dataset as dataset import crnn.seq2seq as crnn def seq2seq_decode(encoder_out, decoder, decoder_input, decoder_hidden, max_length): decoded_words = [] alph = "ABCDEFGHIJKLMNOPQRSTUVWXYZŽŠŪ-\'" converter = utils.ConvertBetweenStringAndLabel(alph) prob = 1.0 for di in range(max_length): decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_out) probs = torch.exp(decoder_output) _, topi = decoder_output.data.topk(1) ni = topi.squeeze(1) decoder_input = ni prob *= probs[:, ni] if ni == utils.EOS_TOKEN: break else: decoded_words.append(converter.decode(ni)) words = ''.join(decoded_words) prob = prob.item() return words, prob def find_median(array_vals): array_vals.sort() mid = len(array_vals) // 2 return array_vals[mid] def detect_centerline(array_vals): max_val = max(array_vals) index_list = [index for index in range(len(array_vals)) if array_vals[index] == max_val] return find_median(index_list) def rotate_image(image, angle): image_center = tuple(np.array(image.shape[1::-1]) / 2) rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) result = cv2.warpAffine(image, rot_mat, image.shape[1::-1], flags=cv2.INTER_LINEAR) return result def extract_peak_ranges_from_array(array_vals, minimum_val=100, minimum_range=2): start_i = None end_i = None peak_ranges = [] for i, val in enumerate(array_vals): if val >= minimum_val and start_i is None: start_i = i elif val >= minimum_val and start_i is not None: pass elif val < minimum_val and start_i is not None: end_i = i if end_i - start_i > minimum_range: peak_ranges.append((start_i, end_i)) start_i = None end_i = None elif val < minimum_val and start_i is None: pass else: raise ValueError("Cannot Parse") return peak_ranges parser = argparse.ArgumentParser() parser.add_argument('--img_path', 
type=str, default='', help='the path of the input image') parser.add_argument('--rot_angle', type=int, default=0, help='the global rotation image') parser.add_argument('--padding', type=int, default=10, help='paddings at the head of the image') parser.add_argument('--block_size', type=int, default=33, help='threshold for binarizing image, odd number only') parser.add_argument('--threshold', type=int, default=32, help='radius to calculate the average for thresholding, even number only') parser.add_argument('--vertical_minimum', type=int, default=800, help='minimal brightness of each VERTICAL line') parser.add_argument('--word_minimum', type=int, default=200, help='minimal brightness of each WORD') parser.add_argument('--blur', type=bool, default=False, help='apply blur to words?') parser.add_argument('--pretrained', type=int, default=1, help='which pretrained model to use') cfg = parser.parse_args() def main(): global_rot_angle = cfg.rot_angle global_padding = cfg.padding imagename = cfg.img_path if cfg.pretrained == 0: my_encoder = "./model/encoder_0.pth" my_decoder = "./model/decoder_0.pth" elif cfg.pretrained == 1: my_encoder = "./model/encoder_1.pth" my_decoder = "./model/decoder_1.pth" else: sys.exit("Unknown Pretrained Model!") print("Analyzing: "+imagename) alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZŽŠŪ-\'" print("Using Möllendorff Alphabet List: " + alphabet + "\n") # len(alphabet) + SOS_TOKEN + EOS_TOKEN num_classes = len(alphabet) + 2 transformer = dataset.ResizeNormalize(img_width=480, img_height=64) image_color = cv2.imread(imagename) image_shape = (image_color.shape[0], image_color.shape[1]) image_binary = cv2.cvtColor(image_color, cv2.COLOR_BGR2GRAY) image = cv2.rotate(image_binary, cv2.ROTATE_90_COUNTERCLOCKWISE) adaptive_threshold = cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, cfg.block_size, cfg.threshold) adaptive_threshold = rotate_image(adaptive_threshold, global_rot_angle) adaptive_threshold = 
cv2.copyMakeBorder(adaptive_threshold, 20, 20, 20, 20, cv2.BORDER_CONSTANT, 0) adaptive_threshold = adaptive_threshold[10:adaptive_threshold.shape[0]-10, 10:adaptive_threshold.shape[1]-10] image_blur = cv2.GaussianBlur(adaptive_threshold,(3,3),cv2.BORDER_DEFAULT) cv2.imshow('Binary Image', cv2.rotate(adaptive_threshold, cv2.ROTATE_90_CLOCKWISE)) cv2.waitKey(1) vertical_sum = np.sum(image_blur, axis=1) peak_ranges = extract_peak_ranges_from_array(vertical_sum,minimum_val=cfg.vertical_minimum,minimum_range=5) img_display = np.copy(adaptive_threshold) #peak_ranges.append((peak_ranges[-1][1],adaptive_threshold.shape[0])) peak_ranges.reverse() horizontal_peak_ranges2d = [] for peak_range in peak_ranges: start_y = 0 end_y = img_display.shape[1] image_x = image_blur[peak_range[0]:peak_range[1], start_y:end_y] horizontal_sum = np.sum(image_x,axis = 0) # plt.plot(horizontal_sum, range(horizontal_sum.shape[0])) # plt.gca().invert_yaxis() # plt.show() horizontal_peak_ranges = extract_peak_ranges_from_array(horizontal_sum,minimum_val=cfg.word_minimum,minimum_range=5) horizontal_peak_ranges2d.append(horizontal_peak_ranges) for hor in horizontal_peak_ranges: cv2.rectangle(img_display, (hor[0], peak_range[0]), (hor[1], peak_range[1]), 140, 1) word_piece = adaptive_threshold[peak_range[0]:peak_range[1],hor[0]:hor[1]] if cfg.blur: word_piece = cv2.GaussianBlur(word_piece,(1,1),cv2.BORDER_DEFAULT) else: pass image_dimension = (word_piece.shape[0], word_piece.shape[1]) #cv2.imshow('Words', word_piece) #print(word_piece.shape) if image_dimension[0] < 30 or image_dimension[1] < 20: pass else: factor = 1 image_resized = cv2.resize(word_piece, (int(image_dimension[1]*factor),int(image_dimension[0]*factor)), interpolation = cv2.INTER_AREA) hor_sum = np.sum(image_resized, axis=1) ctr_line = detect_centerline(hor_sum) image_dimension_new = (image_resized.shape[0], image_resized.shape[1]) add_padding = max([ctr_line, image_dimension_new[0]-ctr_line]) # cv2.imshow('current Image', 
image_resized) # cv2.waitKey(0) if image_dimension_new[1]<=500: padded = cv2.copyMakeBorder(image_resized, add_padding-ctr_line, add_padding-image_dimension_new[0]+ctr_line, 0, 0, cv2.BORDER_CONSTANT, 0) else: padded = image_resized factor = 64/padded.shape[0] padded = cv2.resize(padded, (int(padded.shape[1]*factor),int(padded.shape[0]*factor)), interpolation = cv2.INTER_AREA) padded = cv2.copyMakeBorder(padded, 0, 0, global_padding, 480 - global_padding - padded.shape[0], cv2.BORDER_CONSTANT, 0) padded = Image.fromarray(np.uint8(padded)).convert('L') padded = transformer(padded) padded = padded.view(1, *padded.size()) padded = torch.autograd.Variable(padded) encoder = crnn.Encoder(1, 1024) # no dropout during inference decoder = crnn.Decoder(1024, num_classes, dropout_p=0.0, max_length=121) map_location = 'cpu' encoder.load_state_dict(torch.load(my_encoder, map_location=map_location)) decoder.load_state_dict(torch.load(my_decoder, map_location=map_location)) encoder.eval() decoder.eval() encoder_out = encoder(padded) max_length = 121 decoder_input = torch.zeros(1).long() decoder_hidden = decoder.initHidden(1) words, prob = seq2seq_decode(encoder_out, decoder, decoder_input, decoder_hidden, max_length) print(words+" ", end = '') print("\n") cv2.destroyAllWindows() cv2.imshow('Current Line', cv2.rotate(img_display, cv2.ROTATE_90_CLOCKWISE)) cv2.waitKey(1) #input("Reading Completed, Press Any Key to Exit. 
Ambula Baniha.") # color = (0, 0, 255) # for i, peak_range in enumerate(peak_ranges): # for horizontal_range in horizontal_peak_ranges2d[i]: # x = peak_range[0] # y = horizontal_range[0] # w = peak_range[1] # h = horizontal_range[1] # patch = adaptive_threshold[x:w,y:h] # cv2.rectangle(img_display, (y,x), (h,w), 255, 2) # # print(cnt) # # cv2.imwrite("/Users/zhuohuizhang/Downloads/ManchuOCR/Data/"+fontname+"/Result/"+'%d' %cnt + '.jpg', patch) # cnt += 1 # # cv2.imshow('Vertical Segmented Image', line_seg_blur) # cv2.waitKey(0) if __name__ == "__main__": main() 根据上面代码打包的readmanchu.exe创建一个新的图形化py脚本(使用tkinter),新的图形化py脚本可以执行readmanchu.exe和参数:--img_path(该参数是必填,该选项是选择图片)、--rot_angle(旋转角度,默认是0,选填)、--padding(图形头部的填充,默认值是10,选填)、--block_size(图像二值化阈值,仅限奇数,默认值是33,选填)、--threshold(计算阈值平均值的半径,仅限偶数,默认值32,选填)、--vertical_minimum(每条垂直线的最小亮度,默认值800,选填)、--word_minimum(每个单词的最小亮度,默认值200,选填)、--blur(对文字应用模糊效果?,默认值False)、----pretrained(使用哪个预训练模型?默认值1,int类型),要求所有参数都是通过键盘输入的方式,在输入参数后点击确认即可将readmanchu.exe打印的信息显示出来,并将打印的信息输出到当前文件夹下的.\output\选中图片的名称.txt(自动在当前目录下创建目录和文件);点击确认后就会开始调用 readmanchu.exe,实际运行效果是:(举例) readmanchu.exe --img_path .\examples\001.png --rot_angle 0 --padding 10 --block_size 33 --threshold 32 --vertical_minimum 800 --word_minimum 300 --blur False ----pretrained 1 实际就相当于后台调用cmd命令行在执行readmanchu.exe
09-28
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值