# 这是MatchedImageLoad类 (This is the MatchedImageLoad class.)
class MatchedImageLoad:
    """Index-able loader pairing RGB images with matched crops of a grayscale reference image.

    For every ``*.jpg`` under ``<rgb_dir>/new_image`` the loader:

    1. looks up the image's record in the KLARF file;
    2. converts that record's coordinates into a pixel position on the
       grayscale reference image;
    3. refines that position with ``register_translation``;
    4. returns the normalised grayscale crop together with the normalised
       RGB image.

    Both outputs are multiplied by zero when the mutual information between
    them is below the configured threshold.

    Supports ``len()`` and integer indexing, so it can be used like a
    map-style dataset.
    """

    # Fixed image size (square, in pixels) for preprocessing and model input.
    _IMG_SIZE = 600

    def __init__(self, rgb_dir, chip_id, setupid, klarf_file, gray_file, mindir_train, plot=False):
        """Load the configuration, the KLARF records and the grayscale reference image.

        Args:
            rgb_dir: Directory holding ``klarf_file``, ``gray_file`` and a
                ``new_image`` sub-directory of ``*.jpg`` RGB images.
            chip_id: Chip identifier forwarded to ``get_rgb_range_thres``.
            setupid: Setup identifier forwarded to ``get_rgb_range_thres``.
            klarf_file: KLARF file name, relative to ``rgb_dir``.
            gray_file: Grayscale reference image file name, relative to ``rgb_dir``.
            mindir_train: Selects which KLARF record fields hold the
                coordinates: indices 4/5 when truthy, 2/3 otherwise.
            plot: When True, intermediate and final images are recorded in a
                ``PlotData`` instance available through :meth:`get_plot_data`.

        Raises:
            ValueError: If ``config_reference.json`` fails validation or no
                ``*.jpg`` image is found under ``<rgb_dir>/new_image``.
        """
        # NOTE(review): the config path is relative to the current working directory.
        check_valid_file('config_reference.json', name='config_reference.json')
        try:
            self.config = Config.parse_file("config_reference.json")
        except ValidationError as e:
            raise ValueError(f"Validation for config_reference.json failed! "
                             f"The error is as follows:\n {e.json()}") from e
        self.klarf_file = os.path.join(rgb_dir, klarf_file)
        self.gray_file = os.path.join(rgb_dir, gray_file)
        self.plot = plot
        self.mindir_train = mindir_train
        # Always define the attribute; it only holds a PlotData when plotting is on.
        self.plot_data = PlotData() if self.plot else None
        check_valid_dir(os.path.join(rgb_dir, "new_image"), name='new_image')
        # glob returns files in arbitrary, filesystem-dependent order; sort so a
        # given index always refers to the same image.
        self.rgb_files = sorted(glob.glob(os.path.join(rgb_dir, "new_image/*.jpg")))
        if not self.rgb_files:
            raise ValueError('Error! No images found!')
        self.jko = KlarfParser(self.klarf_file)
        self.gray = imread_with_check(self.gray_file, name='gray_file', flags=cv2.IMREAD_GRAYSCALE)
        log.info(f"Gray Image shape: {self.gray.shape}")
        (self.init_range_x, self.init_range_y, self.threshold, self.w_die, self.h_die,
         self.w_die_res, self.h_die_res, self.w_off, self.h_off,
         self.gray_size) = get_rgb_range_thres(chip_id, setupid, self.config.range_lut)
        log.info(f"dataset size is {len(self.rgb_files)}")

    def __getitem__(self, index):
        """Return the ``(gray, rgb)`` model inputs for the image at ``index``.

        Raises:
            IndexError: If ``index`` is out of range for the image list.
            ValueError: If the image has no KLARF record, the record has fewer
                than 6 fields, or ``w_die``/``h_die`` is zero.
        """
        img_file = self.rgb_files[index]
        image_name = os.path.basename(img_file)
        rgb_img = imread_with_check(img_file, name=image_name, flags=cv2.IMREAD_COLOR)
        x_center, y_center = self._klarf_to_pixel(image_name)
        p, q = self._search_window()
        ths = self._IMG_SIZE // 2
        # Coarse reference crop (2 * gray_size square) around the predicted
        # position, shifted by the centre of the search range.
        gray_ref = cv2.getRectSubPix(self.gray, (self.gray_size * 2, self.gray_size * 2),
                                     (int(x_center + p[1] - ths), int(y_center + p[0] - ths)))
        # Paste the grayscale version of the resized RGB image into the middle
        # of a zero canvas twice the model size; the zero border marks padding.
        gray_img = np.zeros((self._IMG_SIZE * 2, self._IMG_SIZE * 2))
        rgb_img = cv2.resize(rgb_img, (self._IMG_SIZE, self._IMG_SIZE))
        gray_img[ths:ths + self._IMG_SIZE, ths:ths + self._IMG_SIZE] = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2GRAY)
        # The interpolation flag must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, so passing it positionally silently
        # falls back to bilinear interpolation and blurs the zero border.
        gray_img = cv2.resize(gray_img, (self.gray_size * 2, self.gray_size * 2),
                              interpolation=cv2.INTER_NEAREST)
        if self.plot:
            self.plot_data.image_name = image_name
            self.plot_data.inter_img = gray_img
            self.plot_data.inter_ref = gray_ref
        return self.compute_rgb_and_ray(gray_img, gray_ref, p, q, rgb_img, ths, x_center, y_center)

    def _klarf_to_pixel(self, image_name):
        """Map the KLARF record of ``image_name`` to an ``(x, y)`` pixel position on the gray image."""
        txt_info = self.jko.get_klarf(image_name)
        if not txt_info:
            raise ValueError(f"Please check if the {image_name} in the klarf file.")
        if len(txt_info) < 6:
            raise ValueError('The length of txt_info must be longer than 5!')
        x_rel = np.float32(txt_info[4 if self.mindir_train else 2])
        y_rel = np.float32(txt_info[5 if self.mindir_train else 3])
        if not self.w_die or not self.h_die:
            raise ValueError('Params w_die or h_die can not be zero.')
        # Scale die coordinates to die pixels and add the offsets.
        # y is inverted (h_die_res - ...), presumably because KLARF y grows
        # upward while image rows grow downward -- TODO confirm.
        x_center = int(round((x_rel / self.w_die) * self.w_die_res)) + self.w_off
        y_center = int(round((self.h_die_res - (y_rel / self.h_die) * self.h_die_res))) + self.h_off
        return x_center, y_center

    def _search_window(self):
        """Return ``(p, q)``: centre and half-extent of the search range as int32 ``[y, x]`` arrays."""
        p = np.array(
            [(self.init_range_y[1] + self.init_range_y[0]) / 2, (self.init_range_x[1] + self.init_range_x[0]) / 2],
            dtype=np.int32)
        q = np.array(
            [(self.init_range_y[1] - self.init_range_y[0]) / 2, (self.init_range_x[1] - self.init_range_x[0]) / 2],
            dtype=np.int32)
        return p, q

    def __len__(self):
        """Return the number of RGB images found at construction time."""
        return len(self.rgb_files)

    @staticmethod
    def get_zncc(gray_img, gray_ref):
        """Return the zero-normalised cross-correlation of two equally shaped images.

        Pixels that are zero in either image are treated as invalid and set to
        zero in both before the statistics are computed. The masking is
        applied to copies, so the caller's arrays are left untouched.

        Raises:
            ZeroDivisionError: If either masked image has zero standard deviation.
        """
        img = np.array(gray_img)  # copy, same dtype as the input
        ref = np.array(gray_ref)
        invalid = (img == 0) | (ref == 0)
        img[invalid] = 0
        ref[invalid] = 0
        std_img = np.std(img)
        std_ref = np.std(ref)
        # NumPy float division by zero yields inf/nan with a warning rather than
        # raising, so the zero-divisor case must be detected explicitly.
        if std_img == 0 or std_ref == 0:
            raise ZeroDivisionError('Divisor can not be zero.')
        return np.sum((img - np.mean(img)) * (ref - np.mean(ref))) / std_img / std_ref / img.size

    def get_plot_data(self):
        """Return the ``PlotData`` recorded by the most recent item access.

        Raises:
            ValueError: If the loader was created with ``plot=False``.
        """
        if not self.plot:
            raise ValueError('The instance attribute "plot" must be True.')
        return self.plot_data

    def change_range(self, new_range_x, new_range_y):
        """Replace the x/y search ranges used by subsequent item accesses."""
        self.init_range_x = new_range_x
        self.init_range_y = new_range_y

    def compute_rgb_and_ray(self, gray_img, gray_ref, p, q, rgb_img, ths, x_center, y_center):
        """Refine the match by registration and build the normalised model inputs.

        Args:
            gray_img: Padded grayscale canvas of the RGB image (2 * gray_size square).
            gray_ref: Coarse reference crop (2 * gray_size square).
            p: Search-range centre as int32 ``[y, x]``.
            q: Search-range half-extent as int32 ``[y, x]``; used as the maximum shift.
            rgb_img: RGB image resized to ``_IMG_SIZE`` square, BGR channel order.
            ths: Half of ``_IMG_SIZE``.
            x_center: Predicted x pixel position on the gray image.
            y_center: Predicted y pixel position on the gray image.

        Returns:
            ``(gray, rgb)`` float32 arrays of shape ``(1, S, S)`` and
            ``(C, S, S)`` with ``S = _IMG_SIZE``. Both are mapped from
            [0, 255] to [-1, 1]. Both are multiplied by zero when the mutual
            information is below ``self.threshold``.
        """
        shifts, _, _ = register_translation(gray_img, gray_ref, upsample_factor=15, window_size=7, order=3,
                                            max_shifts=q, space="real", hamming=True, use_scipy=True,
                                            shifts_lb=None, shifts_ub=None, show=0, confidence=0)
        center_y, center_x = p - shifts
        ref_index = int(round((x_center - ths) + center_x)), int(round((y_center - ths) + center_y))
        gray_ref = cv2.getRectSubPix(self.gray, (self.gray_size, self.gray_size), ref_index)
        gray_ref_up = cv2.resize(gray_ref, (self._IMG_SIZE, self._IMG_SIZE), interpolation=cv2.INTER_CUBIC)
        mutual_info = mutual_info_score(cv2.cvtColor(rgb_img, cv2.COLOR_BGR2GRAY).flatten(), gray_ref_up.flatten())
        is_matched = mutual_info >= self.threshold
        if self.plot:
            self.plot_data.final_img = rgb_img
            self.plot_data.final_ref = gray_ref
            self.plot_data.score = mutual_info
            self.plot_data.is_matched = bool(is_matched)
        # A uniform 1/0 factor: keeps a matched pair, blanks an unmatched one.
        mask = np.float32(1.0 if is_matched else 0.0)
        rgb = (np.transpose(rgb_img, (2, 0, 1)) / 255. * 2 - 1).astype(np.float32) * mask
        gray = (np.expand_dims(gray_ref_up, 0) / 255. * 2 - 1).astype(np.float32) * mask
        return gray, rgb
# 最新发布 (Latest release)