SRCNN 代码解读【utils.py】

博主近期在学习SRCNN,阅读相关代码并做好笔记,还给出了代码下载链接https://github.com/tegg89/SRCNN-Tensorflow 。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

最近在学习SRCNN,阅读代码做好笔记
代码下载链接https://github.com/tegg89/SRCNN-Tensorflow
下面开始

"""
Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
"""

import os  #导入os库,主要用于系统命令处理
import glob #导入glob库,作用是类似于系统的文件路径匹配查询
import h5py #h5py库,主要用于读取或创建datasets或groups
import random  #随机数库,主要用于生成随机数
import matplotlib.pyplot as plt #导入matlpotlib.pyplot,画图,数据可视化

from PIL import Image  # for loading images as YCbCr format
import scipy.misc #该库主要用于将数组保存成图像形式
import scipy.ndimage #该库用于图像处理
import numpy as np

import tensorflow as tf

try:
  xrange
except:
  xrange = range    #处理异常中断
  
FLAGS = tf.app.flags.FLAGS  #命令行参数传递

def read_data(path):
  """
  Read h5 format data file
  
  Args:
    path: file path of desired file
    data: '.h5' file format that contains train data values
    label: '.h5' file format that contains train label values
    
    读取h5格式的数据文件
    
    参数:
    路径:所需文件的文件路径
     数据:”。包含训练数据值的h5'文件格式
     标签:”。包含训练标签值的h5'文件格式
      """
      with h5py.File(path, 'r') as hf:
        data = np.array(hf.get('data'))
        label = np.array(hf.get('label'))
        return data, label

def preprocess(path, scale=3):#定义预处理函数
  """
  Preprocess single image file 
    (1) Read original image as YCbCr format (and grayscale as default)
    (2) Normalize
    (3) Apply image file with bicubic interpolation

  Args:
    path: file path of desired file
    input_: image applied bicubic interpolation (low-resolution)
    label_: image with original resolution (high-resolution)


    预处理单个图像文件
    (1)读取原始图像为YCbCr格式(默认灰度)
    (2)标准化
    (3)应用双三次插值的图像文件

     参数:
     路径:所需文件的文件路径
     input_:图像应用双三次插值(低分辨率)
     标签_:原始分辨率图像(高分辨率)
  """
  image = imread(path, is_grayscale=True)
  label_ = modcrop(image, scale)  #对图像进行缩放操作

  # Must be normalized
  image = image / 255.
  label_ = label_ / 255. #归一化操作

  input_ = scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False)  #先把labei即清晰的原图变模糊,没有使用预滤波
  input_ = scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False)  #再把变模糊的图像放大,没有使用预滤波

  return input_, label_

def prepare_data(sess, dataset):
  """
  Args:
    dataset: choose train dataset or test dataset
    
    For train dataset, output data would be ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp']
  """
  if FLAGS.is_train:
    filenames = os.listdir(dataset)
    data_dir = os.path.join(os.getcwd(), dataset)   #路径拼接
    data = glob.glob(os.path.join(data_dir, "*.bmp")) #路径查询匹配,返回所有匹配的文件路径列表
  else:
    data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set5")
    data = glob.glob(os.path.join(data_dir, "*.bmp"))

  return data

def make_data(sess, data, label):
  """
  Make input data as h5 file format
  Depending on 'is_train' (flag value), savepath would be changed.
  """
  if FLAGS.is_train:
    savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5')
  else:
    savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5')

  with h5py.File(savepath, 'w') as hf:
    hf.create_dataset('data', data=data)
    hf.create_dataset('label', data=label)

def imread(path, is_grayscale=True):
  """
  Read image using its path.
  Default value is gray-scale, and image is read by YCbCr format as the paper said.
  使用图像的路径读取图像。
  默认值为灰度,图像采用YCbCr格式读取。
  """
  if is_grayscale:
    return scipy.misc.imread(path, flatten=True, mode='YCbCr').astype(np.float)
  else:
    return scipy.misc.imread(path, mode='YCbCr').astype(np.float)

def modcrop(image, scale=3):
  """
  To scale down and up the original image, first thing to do is to have no remainder while scaling operation.
  
  We need to find modulo of height (and width) and scale factor.
  Then, subtract the modulo from height (and width) of original image size.
  There would be no remainder even after scaling operation.
  要对原始图像进行缩放,首先要做的是在缩放操作时没有余数。
  我们需要找到高度(和宽度)与比例因子的模。
  然后,从原始图像大小的高度(和宽度)中减去模。
  即使在缩放操作之后也不会有余数。
  """
  if len(image.shape) == 3:
    h, w, _ = image.shape
    h = h - np.mod(h, scale)  #mod的作用是取余数,这里返回的是h除以scale后得到的余数
    w = w - np.mod(w, scale)  #同上
    image = image[0:h, 0:w, :]
  else:
    h, w = image.shape
    h = h - np.mod(h, scale)
    w = w - np.mod(w, scale)
    image = image[0:h, 0:w]
  return image

def input_setup(sess, config):
  """
  Read image files and make their sub-images and saved them as a h5 file format.
  读取图像文件并生成它们的子图像,并将它们保存为h5文件格式。
  """
  # Load data path
  if config.is_train:
    data = prepare_data(sess, dataset="Train")
  else:
    data = prepare_data(sess, dataset="Test")

  sub_input_sequence = []
  sub_label_sequence = []
  padding = abs(config.image_size - config.label_size) / 2 # 6

  if config.is_train: 
    for i in range(len(data)):#一幅图作为一个data
      input_, label_ = preprocess(data[i], config.scale)#得到data[]的LR和HR图input_和label_
      if len(input_.shape) == 3:
        h, w, _ = input_.shape
      else:
        h, w = input_.shape
      #把input_和label_分割成若干自图sub_input和sub_label
      for x in range(0, h-config.image_size+1, config.stride):
        for y in range(0, w-config.image_size+1, config.stride):
          sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
          sub_label = label_[x+padding:x+padding+config.label_size, y+padding:y+padding+config.label_size] # [21 x 21]
          sub_input = sub_input.reshape([config.image_size, config.image_size, 1])#按image size大小重排 因此 imgae_size应为33 而label_size应为21
          sub_label = sub_label.reshape([config.label_size, config.label_size, 1])
          sub_input_sequence.append(sub_input)#在sub_input_sequence末尾加sub_input中元素 但考虑为空
          sub_label_sequence.append(sub_label)
  else:
        #测试
        input_, label_ = preprocess(data[0], config.scale)#测试图片
        if len(input_.shape) == 3:
          h, w, _ = input_.shape
        else:
          h, w = input_.shape
        nx = 0 #后注释
        ny = 0 #后注释
        #自图需要进行合并操作
        for x in range(0, h-config.image_size+1, config.stride): #x从0到h-33+1 步长stride(21)
          nx += 1
          ny = 0
          for y in range(0, w-config.image_size+1, config.stride):#y从0到w-33+1 步长stride(21)
            ny += 1
            #分块sub_input=input_[x:x+33,y:y+33]  sub_label=label_[x+6,x+6+21, y+6,y+6+21]
            sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
            sub_label = label_[x+padding:x+padding+config.label_size, y+padding:y+padding+config.label_size] # [21 x 21] 
            sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  
            sub_label = sub_label.reshape([config.label_size, config.label_size, 1])
            sub_input_sequence.append(sub_input)
            sub_label_sequence.append(sub_label)
  # 上面的部分和训练是一样的
  arrdata = np.asarray(sub_input_sequence) # [?, 33, 33, 1]
  arrlabel = np.asarray(sub_label_sequence) # [?, 21, 21, 1]
  make_data(sess, arrdata, arrlabel)#存成h5格式
  if not config.is_train:
    return nx, ny 
def imsave(image, path):
  return scipy.misc.imsave(path, image)
def merge(images, size):
  h, w = images.shape[1], images.shape[2] #觉得下标应该是0,1
  #h, w = images.shape[0], images.shape[1]
  img = np.zeros((h*size[0], w*size[1], 1))
  for idx, image in enumerate(images):
    i = idx % size[1]
    j = idx // size[1]
    img[j*h:j*h+h, i*w:i*w+w, :] = image
  return img
<think>根据用户需求,我们需要创建一个名为`utils.py`的模块,其中包含与继承和多态相关的工具函数或类。考虑到用户之前的需求是关于继承与多态的示例,我们可以设计以下内容:1.在`utils.py`中定义一个基类`Animal`和两个子类`Dog`和`Cat`,以展示继承。2.提供一个多态演示函数,如`animal_sound(animal)`,它接受一个动物对象并调用其发声方法。3.提供一个创建动物对象的工厂函数(可以使用类方法或普通函数)。同时,我们将编写一个`main.py`来导入并使用`utils.py`模块中的类和方法。##设计`utils.py`模块```python#utils.pyclassAnimal:"""动物基类"""def__init__(self,name,age):self.name=nameself.age=agedefmake_sound(self):"""动物发声(基类方法,由子类重写)"""return"动物发出声音"definfo(self):"""返回动物信息"""returnf"名称:{self.name},年龄:{self.age}岁,类型:{self.__class__.__name__}"classDog(Animal):"""狗类,继承自动物基类"""def__init__(self,name,age,breed="未知品种"):super().__init__(name,age)self.breed=breeddefmake_sound(self):"""狗的发声方法"""return"汪汪!"definfo(self):"""返回狗的信息,包括品种"""returnf"{super().info()},品种:{self.breed}"classCat(Animal):"""猫类,继承自动物基类"""defmake_sound(self):"""猫的发声方法"""return"喵喵!"defclimb_tree(self):"""猫的特有方法"""returnf"{self.name}正在爬树"defanimal_sound(animal):"""多态演示函数:传入动物对象,返回其声音"""returnanimal.make_sound()defcreate_animal(animal_type,name,age,**kwargs):"""工厂函数:根据类型创建动物对象"""animal_classes={'dog':Dog,'cat':Cat}ifanimal_type.lower()inanimal_classes:#对于Dog需要额外的breed参数ifanimal_type.lower()=='dog':breed=kwargs.get('breed','未知品种')returnDog(name,age,breed)else:returnanimal_classes[animal_type.lower()](name,age)else:#默认返回基类AnimalreturnAnimal(name,age)```##设计`main.py`使用`utils.py`模块```python#main.pyimportutilsdefmain():"""主函数,演示继承和多态"""#使用工厂函数创建动物animals=[utils.create_animal('dog','旺财',2,breed='金毛'),utils.create_animal('cat','咪咪',1),utils.create_animal('bird','小鸟',1)#不存在的类型,将返回基类Animal]#多态演示:每个动物发出声音print("=====多态演示:动物声音=====")foranimalinanimals:print(f"{animal.name}:{utils.animal_sound(animal)}")#显示动物信息print("\n=====动物信息=====")foranimalinanimals:print(animal.info())#尝试调用特有方法(注意:基类Animal没有climb_tree方法)print("\n=====特有方法演示=====")foranimalinanimals:ifisinstance(animal,utils.Cat):print(animal.climb_tree())elifisinstance(animal,utils.Dog):print(f"{animal.name}是一只{animal.breed},它会{animal.make_sound()}叫")if__name__=="__main__":main()```##运行结果示例```=====多态演示:动物声音=====旺财:汪汪!咪咪:喵喵!小鸟:动物发出声音=====动物信息=====名称:旺财,年龄:2岁,类型:Dog,品种:金毛名称:咪咪,年龄:1岁,类型:Cat名称:小鸟,年龄:1岁,类型:Animal=====特有方法演示=====咪咪正在爬树旺财是一只金毛,它会汪汪!叫```##代码说明1.**模块化设计**:将动物类和多态函数放在`utils.py`中,主程序在`main.py`中导入使用[^1][^2]。2.**工厂函数**:`create_animal`函数根据传入的动物类型字符串动态创建对应的动物对象,支持扩展(如添加新的动物类型只需更新工厂函数的映射字典)[^3]。3.**多态演示**:`animal_sound`函数接受任意动物对象,调用其`make_sound`方法,体现了多态性。4.**类型检查**:在`main.py`中,使用`isinstance`检查对象类型,以安全调用子类特有方法(如`climb_tree`)[^4]。##扩展建议1.若要添加新的动物类型(如`Bird`),只需在`utils.py`中添加一个新类并更新工厂函数的映射字典。2.可以使用抽象基类(ABC)来强制子类实现`make_sound`方法,确保多态的正确性。##相关问题1.如何在工厂函数中使用类方法替代?2.如何使用抽象基类(ABC)来确保子类实现了特定的方法?3.如何处理动物之间的互动(例如狗追猫)?4.如何将动物对象持久化到文件或数据库?</think># Python 继承与多态 utils.py 模块实现 下面是一个完整的 `utils.py` 模块实现,包含面向对象编程中的继承与多态特性,适合初学者理解和学习。 ## utils.py 模块代码 ```python # utils.py """ 动物模块 - 演示Python继承与多态 包含基类Animal及其子类Dog和Cat """ class Animal: """动物基类,定义所有动物的共同特征""" def __init__(self, name, age): """ 初始化动物实例 :param name: 动物名称 :param age: 动物年龄 """ self.name = name self.age = age def make_sound(self): """动物发出声音(基类方法,被子类重写)""" return "动物发出声音" def eat(self): """动物进食(基类方法)""" return f"{self.name}正在进食" def sleep(self): """动物睡觉(基类方法)""" return f"{self.name}正在睡觉" def info(self): """获取动物信息(基类方法)""" return f"名称: {self.name}, 年龄: {self.age}岁, 类型: {self.__class__.__name__}" class Dog(Animal): """狗类,继承自动物基类""" def __init__(self, name, age, breed): """ 初始化狗实例 :param breed: 狗的品种 """ super().__init__(name, age) # 调用父类构造方法[^1] self.breed = breed def make_sound(self): """重写make_sound方法(多态实现)[^1]""" return "汪汪!" def fetch(self): """狗特有的方法""" return f"{self.name}正在捡球" def info(self): """重写info方法,添加品种信息""" return f"{super().info()}, 品种: {self.breed}" class Cat(Animal): """猫类,继承自动物基类""" def make_sound(self): """重写make_sound方法""" return "喵喵!" def climb_tree(self): """猫特有的方法""" return f"{self.name}正在爬树" def animal_sound(animal): """ 多态演示函数:接受任何Animal类型对象并调用其make_sound方法 :param animal: Animal或其子类的实例 :return: 动物发出的声音 """ return animal.make_sound() def animal_info(animal): """ 获取动物信息的通用函数 :param animal: Animal或其子类的实例 :return: 动物信息字符串 """ return animal.info() def create_animal(animal_type, name, age, **kwargs): """ 动物工厂函数:根据类型创建不同的动物对象 :param animal_type: 动物类型 ('dog', 'cat', 'animal') :param name: 动物名称 :param age: 动物年龄 :param kwargs: 额外参数(如狗的品种) :return: Animal实例 """ if animal_type.lower() == 'dog': breed = kwargs.get('breed', '未知品种') return Dog(name, age, breed) elif animal_type.lower() == 'cat': return Cat(name, age) else: return Animal(name, age) ``` ## 使用示例 (main.py) ```python # main.py from utils import Animal, Dog, Cat, animal_sound, animal_info, create_animal def main(): """主函数,演示继承和多态""" # 创建动物实例 generic_animal = Animal("普通动物", 3) my_dog = Dog("旺财", 2, "金毛") my_cat = Cat("咪咪", 1) # 使用工厂函数创建动物 created_dog = create_animal('dog', '小黑', 3, breed='拉布拉多') created_cat = create_animal('cat', '小白', 4) # 演示继承 print("===== 继承演示 =====") print(generic_animal.eat()) # 调用基类方法 print(my_dog.sleep()) # 子类继承基类方法 print(my_cat.eat()) # 子类继承基类方法 # 演示多态性 print("\n===== 多态性演示 =====") animals = [ generic_animal, my_dog, my_cat, created_dog, created_cat ] for animal in animals: print(f"{animal.name}说: {animal_sound(animal)}") # 多态调用 # 使用通用信息函数 print("\n===== 动物信息 =====") for animal in animals: print(animal_info(animal)) # 多态调用 # 调用特有方法 print("\n===== 特有方法演示 =====") print(my_dog.fetch()) print(my_cat.climb_tree()) if __name__ == "__main__": main() ``` ## 运行结果示例 ``` ===== 继承演示 ===== 普通动物正在进食 旺财正在睡觉 咪咪正在进食 ===== 多态性演示 ===== 普通动物说: 动物发出声音 旺财说: 汪汪! 咪咪说: 喵喵! 小黑说: 汪汪! 小白说: 喵喵! ===== 动物信息 ===== 名称: 普通动物, 年龄: 3岁, 类型: Animal 名称: 旺财, 年龄: 2岁, 类型: Dog, 品种: 金毛 名称: 咪咪, 年龄: 1岁, 类型: Cat 名称: 小黑, 年龄: 3岁, 类型: Dog, 品种: 拉布拉多 名称: 小白, 年龄: 4岁, 类型: Cat ===== 特有方法演示 ===== 旺财正在捡球 咪咪正在爬树 ``` ## utils.py 模块功能详解 ### 1. 类结构 - **Animal类**:动物基类,定义所有动物的共同属性和方法 - **Dog类**:继承Animal,添加狗的特定行为(fetch)和属性(breed) - **Cat类**:继承Animal,添加猫的特定行为(climb_tree) ### 2. 多态实现 - **animal_sound()函数**:接受任何Animal类型对象,调用其make_sound方法[^1] - **animal_info()函数**:通用函数获取动物信息,通过多态调用不同子类的info方法 ### 3. 工厂模式 - **create_animal()函数**:根据类型创建不同的动物对象,实现对象创建的封装[^3] ### 4. 继承特性 - 子类继承父类的属性和方法(如eat(), sleep()) - 子类可以重写父类方法(如make_sound(), info()) - 使用super()调用父类方法实现代码复用[^1] ## 模块化优势 1. **代码复用**:将相关类封装在模块中,可在多个项目中复用[^1] 2. **封装性**:隐藏实现细节,只暴露必要的接口 3. **可扩展性**:轻松添加新动物类型(如Bird类)而不影响现有代码 4. **可维护性**:相关功能集中管理,便于维护和更新 ## 扩展建议 1. 添加更多动物类型: ```python class Bird(Animal): def make_sound(self): return "叽叽喳喳!" def fly(self): return f"{self.name}正在飞翔" ``` 2. 使用抽象基类确保一致性: ```python from abc import ABC, abstractmethod class Animal(ABC): @abstractmethod def make_sound(self): pass ``` 3. 添加动物互动功能: ```python def interact(animal1, animal2): """动物之间互动""" if isinstance(animal1, Dog) and isinstance(animal2, Cat): return f"{animal1.name}在追{animal2.name}!" # 添加更多互动逻辑 ```
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值