Python None comparison: should I use “is” or ==?

Python中is与==的区别
本文解释了在Python编程中如何正确使用is与==。is用于检查两个对象是否为同一个对象,而==则用于比较两个对象的内容是否相等。了解它们的区别对于避免编程错误至关重要。

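Use `is` when you want to check an object's identity (e.g. checking whether `var is None`). Use `==` when you want to check equality (e.g. is `var` equal to 3?). `None` is a singleton, so `var is None` is both the correct and the fastest test. `var == None` can give surprising results, because a class may override `__eq__`.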

 

 

ref: http://stackoverflow.com/questions/14247373/python-none-comparison-should-i-use-is-or

import os

import numpy as np
import einops
from rp import *

try:
    import torch
except ImportError:
    # torch is optional: without it only NumPy inputs can be processed.
    pass

__all__ = ["remove_watermark", "demo_remove_watermark"]

# The watermark template is documented to live next to this file. (The original
# code hard-coded r"E:\watermark.exr", which only worked on the author's machine.)
_WATERMARK_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "watermark.exr")

# Lower bound for (1 - alpha) during inverse alpha blending. Where the watermark
# is fully opaque (alpha == 1) the background is unrecoverable; flooring the
# denominator avoids inf/nan output (nan would survive the final [0, 1] clip).
_MIN_UNBLEND_DENOMINATOR = 1e-6


def _is_uint8(x):
    """Return True if the NumPy array / torch tensor ``x`` has dtype uint8."""
    if is_numpy_array(x):
        return x.dtype == np.uint8
    elif is_torch_tensor(x):
        return x.dtype == torch.uint8
    else:
        raise TypeError(f"Unsupported input type: {type(x)}")


def _fft2(x):
    """2-D FFT over the last two axes, dispatched to NumPy or torch."""
    if is_numpy_array(x):
        return np.fft.fft2(x)
    elif is_torch_tensor(x):
        return torch.fft.fft2(x)
    else:
        raise TypeError(f"Unsupported input type: {type(x)}")


def _ifft2(x):
    """2-D inverse FFT over the last two axes, dispatched to NumPy or torch."""
    if is_numpy_array(x):
        return np.fft.ifft2(x)
    elif is_torch_tensor(x):
        return torch.fft.ifft2(x)
    else:
        raise TypeError(f"Unsupported input type: {type(x)}")


def _fftshift(x):
    """Move the zero-frequency / zero-lag entry to the center (all axes)."""
    if is_numpy_array(x):
        return np.fft.fftshift(x)
    elif is_torch_tensor(x):
        return torch.fft.fftshift(x)
    else:
        raise TypeError(f"Unsupported input type: {type(x)}")


def _clip(x, min_val, max_val):
    """Element-wise clamp of ``x`` into [min_val, max_val] for NumPy or torch."""
    if is_numpy_array(x):
        return np.clip(x, min_val, max_val)
    elif is_torch_tensor(x):
        return torch.clamp(x, min_val, max_val)
    else:
        raise TypeError(f"Unsupported input type: {type(x)}")


def _roll(x, shift, dims):
    """Circularly shift ``x`` by ``shift`` along ``dims`` for NumPy or torch."""
    if is_numpy_array(x):
        return np.roll(x, shift, axis=dims)
    elif is_torch_tensor(x):
        return torch.roll(x, shift, dims=dims)
    else:
        raise TypeError(f"Unsupported input type: {type(x)}")


def _default_form(x):
    """Default axis layout: 'THWC' for NumPy arrays, 'TCHW' for torch tensors."""
    if is_numpy_array(x):
        return "THWC"
    elif is_torch_tensor(x):
        return "TCHW"
    else:
        raise TypeError(f"Unsupported input type: {type(x)}")


def _like(x, target):
    """Convert ``x`` to the array library of ``target``.

    When converting NumPy -> torch, the result is also moved to ``target``'s
    device and cast to its dtype. Torch -> NumPy uses rp's ``as_numpy_array``.
    """
    if is_numpy_array(x) and is_numpy_array(target):
        return x
    elif is_torch_tensor(x) and is_torch_tensor(target):
        return x
    elif is_torch_tensor(x) and is_numpy_array(target):
        return as_numpy_array(x)
    elif is_numpy_array(x) and is_torch_tensor(target):
        return torch.tensor(x).to(target.device, target.dtype)
    else:
        raise TypeError(f"Unsupported input types: {type(x)} {type(target)}")


@memoized
def _get_watermark_image():
    """Load the RGBA watermark template stored next to this file (cached)."""
    watermark = load_image(_WATERMARK_PATH, use_cache=True)
    assert is_rgba_image(watermark), "Without alpha, the watermark is useless"
    assert is_float_image(watermark), "Watermark should ideally be saved with floating-point precision"
    return watermark


def _recover_background(composite_images, rgba_watermark):
    """Undo alpha blending of ``rgba_watermark`` from ``composite_images``.

    Assumes composite = alpha * watermark_rgb + (1 - alpha) * background and
    solves for background. ``rgba_watermark`` is H x W x 4; the RGB and alpha
    slices broadcast against the trailing H x W x C axes of the composites.
    The result is clipped to [0, 1].
    """
    watermark_rgb = rgba_watermark[:, :, :3]
    watermark_alpha = rgba_watermark[:, :, 3:]
    denominator = _clip(1 - watermark_alpha, _MIN_UNBLEND_DENOMINATOR, 1)
    background = (composite_images - watermark_alpha * watermark_rgb) / denominator
    return _clip(background, 0, 1)


def _cross_correlation(img1, img2):
    """Circular cross-correlation of two same-shaped 2-D images via FFT.

    The output is fftshift-ed, so zero lag sits at index (H // 2, W // 2).
    """
    assert is_a_matrix(img1)
    assert is_a_matrix(img2)
    # Correlation theorem: corr = IFFT(FFT(a) * conj(FFT(b))).
    cross_fft = _fft2(img1) * _fft2(img2).conj()
    return _fftshift(_ifft2(cross_fft)).real


def _best_shift(frame, watermark):
    """Return the (dx, dy) circular shift that best aligns ``watermark`` to ``frame``."""
    corr = _cross_correlation(frame, watermark)
    peak_y, peak_x = np.unravel_index(np.argmax(corr), corr.shape)
    # After fftshift, zero lag lives at the center, so subtract it to get a
    # signed shift. Plain ints keep torch.roll happy (no NumPy scalar types).
    dy = int(peak_y) - watermark.shape[0] // 2
    dx = int(peak_x) - watermark.shape[1] // 2
    return dx, dy


def _find_watermark_shift(avg_frame, watermark):
    """Locate the watermark in the averaged frame; returns (dx, dy).

    Runs entirely in NumPy on a single H x W image, so it is cheap even for
    torch inputs.
    """
    frame = as_numpy_array(avg_frame)
    template = as_numpy_array(watermark)
    # NOTE(review): presumably composites the RGBA template over flat 0.5 gray,
    # then re-centers around 0 so transparent areas contribute nothing to the
    # correlation -- confirm against rp.blend_images.
    template = blend_images(0.5, template) - 0.5
    # High-pass the frame (subtract a heavy Gaussian blur) so slowly varying
    # scene content does not dominate the correlation peak.
    frame = frame - cv_gauss_blur(frame, sigma=20)
    frame = as_grayscale_image(frame)
    template = as_grayscale_image(template)
    return _best_shift(frame, template)


def remove_watermark(video, form=None):
    """Remove a known semi-transparent watermark from every frame of a video.

    The watermark template is loaded from ``watermark.exr`` in the same folder
    as this file (an RGBA, floating-point image). It is currently a Shutterstock
    watermark, created with ``make_watermark_exr.py`` from the same folder.

    Args:
        video: An RGB video, either a NumPy array (default layout THWC) or a
            torch tensor (default layout TCHW). T is the number of frames, H/W
            are height/width, and C == 3 channels. uint8 input is rescaled to
            [0, 1]; other dtypes are assumed to already be in [0, 1].
        form (str, optional): Override the default layout: 'TCHW' or 'THWC'.
            For example, pass form='THWC' for a channels-last torch tensor.

    Returns:
        An array/tensor of the same kind and layout as ``video``, with the
        watermark removed and floating-point values clipped to [0, 1].

    Notes:
        1. Locate: average all frames, high-pass and grayscale the average, and
           cross-correlate it (via FFT) with a grayscale version of the
           watermark. The correlation peak gives the circular (dy, dx) shift
           that best aligns the template. Only translation is searched; rotated
           or rescaled watermarks are not handled.
        2. Unblend: roll the RGBA template by that shift and invert
           composite = alpha * watermark + (1 - alpha) * background for every
           frame.

        Cost: one FFT correlation on a single H x W image, O(H*W*log(H*W)),
        plus O(T*H*W) element-wise work for unblending.
    """
    if form is None:
        form = _default_form(video)
    assert form in ['TCHW', 'THWC'], f"form must be 'TCHW' or 'THWC', got {form!r}"

    if form == 'TCHW':
        # Work channels-last internally, then restore the caller's layout.
        video = einops.rearrange(video, 'T C H W -> T H W C')
        recovered = remove_watermark(video, form='THWC')
        return einops.rearrange(recovered, 'T H W C -> T C H W')

    if _is_uint8(video):
        video = video / 255

    avg_frame = video.mean(0)  # H x W x C

    # Make the watermark match the frame size and array library so the two can
    # be correlated and blended directly.
    # NOTE(review): relies on rp.crop_image producing exactly h x w -- confirm
    # its behavior when the video is larger than the template.
    watermark = _get_watermark_image()
    h, w, _ = avg_frame.shape
    watermark = crop_image(watermark, h, w)
    watermark = _like(watermark, avg_frame)

    x_shift, y_shift = _find_watermark_shift(avg_frame, watermark)
    aligned_watermark = _roll(watermark, (y_shift, x_shift), dims=(0, 1))
    return _recover_background(video, aligned_watermark)


def demo_remove_watermark(input_video_glob="webvid/*.mp4", device=None):
    """Run remove_watermark on a set of videos and save side-by-side comparisons.

    The videos matching ``input_video_glob`` are processed in shuffled order.
    Each is resampled to 60 frames and processed. A vertical stack of the
    result (top) and the original (bottom) is saved under
    'comparison_videos/', named after the input file, and then displayed.

    Args:
        input_video_glob: Glob pattern selecting the videos to process.
            Defaults to 'webvid/*.mp4'.
        device: If None, NumPy is used. If a string like 'cpu' or 'cuda',
            the video is converted to a torch tensor on that device.
    """
    test_videos = shuffled(rp_glob(input_video_glob))

    while test_videos:
        video_path = test_videos.pop()
        fansi_print("Loading video from " + video_path, "green", "bold")
        tic()

        video = load_video(video_path, use_cache=False)
        video = as_numpy_array(resize_list(video, length=60))  # THWC

        if device is not None:
            video = torch.tensor(video, device=device)
            fansi_print("Using pytorch on device = " + repr(device), 'green', 'bold')
        else:
            fansi_print("Using numpy on device = " + repr(device), 'magenta', 'bold')

        ptoctic()
        # The tensor is still channels-last, but torch input defaults to TCHW,
        # so the layout must be stated explicitly.
        recovered = remove_watermark(video, form="THWC")
        ptoc()

        # For demo purposes we must convert it back to numpy arrays
        recovered = as_numpy_array(recovered)
        video = as_numpy_array(video)
        analy_video = vertically_concatenated_videos(recovered, video)

        output_path = get_unique_copy_path(
            "comparison_videos/"
            + with_file_extension(
                get_file_name(video_path, include_file_extension=False),
                "mp4",
            ),
        )
        saved_path = save_video_mp4(analy_video, output_path, framerate=30)
        fansi_print("Saved video at " + saved_path, "green", "bold")
        display_video(analy_video)


if __name__ == "__main__":
    pip_import('fire').Fire(demo_remove_watermark)

# Question from the original post (translated): "Explain the algorithmic idea
# behind the code above in plain, easy-to-understand terms."
最新发布
08-28
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值