Python3-scipy

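This article introduces the SciPy sub-modules for high-level scientific computing, with usage examples covering file input/output, special functions, linear algebra operations, interpolation, optimization and fitting, and statistics and random number generation.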


References:

1、http://www.scipy-lectures.org/index.html

2、https://github.com/scipy-lectures/scipy-lecture-notes/tree/master/intro/scipy/examples

3、https://github.com/jayleicn/scipy-lecture-notes-zh-CN

4、http://scipy-cookbook.readthedocs.io/

5、https://docs.scipy.org/doc/scipy/reference/


1.5. Scipy : high-level scientific computing


scipy is composed of task-specific sub-modules:

scipy.cluster Vector quantization / Kmeans
scipy.constants Physical and mathematical constants
scipy.fftpack Fourier transform
scipy.integrate Integration routines
scipy.interpolate Interpolation
scipy.io Data input and output
scipy.linalg Linear algebra routines
scipy.ndimage n-dimensional image package
scipy.odr Orthogonal distance regression
scipy.optimize Optimization
scipy.signal Signal processing
scipy.sparse Sparse matrices
scipy.spatial Spatial data structures and algorithms
scipy.special Any special mathematical functions
scipy.stats Statistics


1.5.1 File input/output: scipy.io

Matlab files: Loading and saving:

>>> import numpy as np
>>> from scipy import io as spio
>>> a = np.ones((3, 3))
>>> spio.savemat('file.mat', {'a': a})  # savemat expects a dictionary
>>> data = spio.loadmat('file.mat')
>>> data['a']
array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]])
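
As a hedged illustration of the round trip above: loadmat returns a dictionary that holds the saved variables next to metadata entries whose names start with a double underscore. The sketch below writes to a temporary directory (an assumption made here so that no file.mat is left behind) and filters those entries out.

```python
import os
import tempfile

import numpy as np
from scipy import io as spio

a = np.ones((3, 3))

# Write to a temporary directory so the example leaves no file behind
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'file.mat')
    spio.savemat(path, {'a': a})
    data = spio.loadmat(path)

# Keep only the user variables, dropping the '__header__'-style metadata
variables = {k: v for k, v in data.items() if not k.startswith('__')}
print(sorted(variables))
print(np.array_equal(variables['a'], a))
```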


Image files: Reading images:

>>> # scipy.misc.imread was removed in SciPy 1.2; the imageio package offers a replacement
>>> import imageio.v3 as iio
>>> iio.imread('fname.png')
array(...)
>>> # Matplotlib also has a similar function
>>> import matplotlib.pyplot as plt
>>> plt.imread('fname.png')
array(...)



1.5.3 Linear algebra operations: scipy.linalg

  • The scipy.linalg.det() function computes the determinant of a square matrix:

    >>> import numpy as np
    >>> from scipy import linalg
    >>> arr = np.array([[1, 2],
    ...                 [3, 4]])
    >>> linalg.det(arr)
    -2.0
    >>> arr = np.array([[3, 2],
    ...                 [6, 4]])
    >>> linalg.det(arr) 
    0.0
    >>> linalg.det(np.ones((3, 4)))
    Traceback (most recent call last):
    ...
    ValueError: expected square matrix
    
  • The scipy.linalg.inv() function computes the inverse of a square matrix:

    >>> arr = np.array([[1, 2],
    ...                 [3, 4]])
    >>> iarr = linalg.inv(arr)
    >>> iarr
    array([[-2. ,  1. ],
           [ 1.5, -0.5]])
    >>> np.allclose(np.dot(arr, iarr), np.eye(2))
    True
    

    Finally, computing the inverse of a singular matrix (one whose determinant is zero) raises LinAlgError:

    >>> arr = np.array([[3, 2],
    ...                 [6, 4]])
    >>> linalg.inv(arr)  
    Traceback (most recent call last):
    ...
    ...LinAlgError: singular matrix
    
  • More advanced operations are available, for example singular-value decomposition (SVD):

    >>> arr = np.arange(9).reshape((3, 3)) + np.diag([1, 0, 1])
    >>> uarr, spec, vharr = linalg.svd(arr)
    

    The resulting array of singular values (the spectrum) is:

    >>> spec    
    array([ 14.88982544,   0.45294236,   0.29654967])
    

    The original matrix can be recomposed by matrix multiplication of the outputs of svd, here with the arrays' dot method:

    >>> sarr = np.diag(spec)
    >>> svd_mat = uarr.dot(sarr).dot(vharr)
    >>> np.allclose(svd_mat, arr)
    True
    

    SVD is commonly used in statistics and signal processing. Many other standard decompositions (QR, LU, Cholesky, Schur), as well as solvers for linear systems, are available in scipy.linalg.
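
To make the mention of linear-system solvers concrete, here is a minimal sketch using scipy.linalg.solve and an LU factorization. The 2x2 matrix and right-hand side are made-up values chosen only for illustration.

```python
import numpy as np
from scipy import linalg

# Illustrative system: 3x + 2y = 7 and x + 4y = 9
A = np.array([[3.0, 2.0],
              [1.0, 4.0]])
b = np.array([7.0, 9.0])

# Solve A x = b directly, without forming the inverse of A
x = linalg.solve(A, b)

# An LU factorization can be reused for several right-hand sides
lu, piv = linalg.lu_factor(A)
x_lu = linalg.lu_solve((lu, piv), b)

print(x)
print(np.allclose(A @ x, b), np.allclose(x, x_lu))
```

Solving the system this way is generally preferred over computing linalg.inv(A) and multiplying, since it avoids building the full inverse.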

(The remaining sections are skipped.)

1.5.12 Full code examples for the scipy chapter


1.5.12.1 Finding the minimum of a smooth function

Demos various methods to find the minimum of a function.

import numpy as np
import matplotlib.pyplot as plt

def f(x):
    return x**2 + 10*np.sin(x)


x = np.arange(-10, 10, 0.1)
plt.plot(x, f(x))

Now find the minimum with a few methods

from scipy import optimize

# The default method (BFGS for an unconstrained problem, as the hess_inv output shows)
print(optimize.minimize(f, x0=0))

Out:

      fun: -7.945823375615215
 hess_inv: array([[ 0.08589237]])
      jac: array([ -1.19209290e-06])
  message: 'Optimization terminated successfully.'
     nfev: 18
      nit: 5
     njev: 6
   status: 0
  success: True
        x: array([-1.30644012])

print(optimize.minimize(f, x0=0, method="L-BFGS-B"))

Out:

      fun: array([-7.94582338])
 hess_inv: <1x1 LbfgsInvHessProduct with dtype=float64>
      jac: array([ -1.42108547e-06])
  message: b'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 12
      nit: 5
   status: 0
  success: True
        x: array([-1.30644013])

plt.show()
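
Both calls above start from x0=0. As a hedged sketch of why the starting point matters for a local method, the snippet below runs the same L-BFGS-B search from two starting points; the second starting value, x0=3, is chosen here for illustration.

```python
import numpy as np
from scipy import optimize

def f(x):
    return x**2 + 10*np.sin(x)

# Starting at 0 reaches the minimum near x = -1.3
res_left = optimize.minimize(f, x0=0, method="L-BFGS-B")
# Starting at 3 converges to a different, higher local minimum
res_right = optimize.minimize(f, x0=3, method="L-BFGS-B")

print(res_left.x, res_left.fun)
print(res_right.x, res_right.fun)
```

Local methods such as these return the minimum of the basin they start in, which is why the global search with optimize.brute appears in a later example.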


"""
=========================================
Finding the minimum of a smooth function
=========================================
Demos various methods to find the minimum of a function.
"""

import numpy as np
import matplotlib.pyplot as plt

def f(x):
    return x**2 + 10*np.sin(x)


x = np.arange(-10, 10, 0.1)
plt.plot(x, f(x))

############################################################
# Now find the minimum with a few methods
from scipy import optimize

# The default (Nelder Mead)
print(optimize.minimize(f, x0=0))

############################################################
print(optimize.minimize(f, x0=0, method="L-BFGS-B"))

############################################################

plt.show()


1.5.12.2 Detrending a signal

scipy.signal.detrend() removes a linear trend.

Generate a random signal with a trend

import numpy as np
t = np.linspace(0, 5, 100)
x = t + np.random.normal(size=100)

Detrend

from scipy import signal
x_detrended = signal.detrend(x)

Plot

from matplotlib import pyplot as plt
plt.figure(figsize=(5, 4))
plt.plot(t, x, label="x")
plt.plot(t, x_detrended, label="x_detrended")
plt.legend(loc='best')
plt.show()
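
To show what "removes a linear trend" means, the sketch below compares signal.detrend with subtracting a least-squares straight line by hand. The seeded generator is an assumption added for reproducibility.

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(0)  # seeded here only for reproducibility
t = np.linspace(0, 5, 100)
x = t + rng.normal(size=100)

x_detrended = signal.detrend(x)

# Fit a straight line by least squares and subtract it by hand
slope, intercept = np.polyfit(t, x, 1)
x_manual = x - (slope * t + intercept)

print(np.allclose(x_detrended, x_manual))
```

The two results agree here because t is evenly spaced, so a line in the sample index is also a line in t.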


"""
===================
Detrending a signal
===================
:func:`scipy.signal.detrend` removes a linear trend.
"""

############################################################
# Generate a random signal with a trend
import numpy as np
t = np.linspace(0, 5, 100)
x = t + np.random.normal(size=100)

############################################################
# Detrend
from scipy import signal
x_detrended = signal.detrend(x)

############################################################
# Plot
from matplotlib import pyplot as plt
plt.figure(figsize=(5, 4))
plt.plot(t, x, label="x")
plt.plot(t, x_detrended, label="x_detrended")
plt.legend(loc='best')
plt.show()

1.5.12.3 Resample a signal with scipy.signal.resample

scipy.signal.resample() uses FFT to resample a 1D signal.

Generate a signal with 100 data points

import numpy as np
t = np.linspace(0, 5, 100)
x = np.sin(t)

Downsample it by a factor of 4

from scipy import signal
x_resampled = signal.resample(x, 25)

Plot

from matplotlib import pyplot as plt
plt.figure(figsize=(5, 4))
plt.plot(t, x, label='Original signal')
plt.plot(t[::4], x_resampled, 'ko', label='Resampled signal')

plt.legend(loc='best')
plt.show()
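
The plot above uses t[::4] as the time axis of the resampled signal. As a small sketch, signal.resample can also return that axis itself when the original sample times are passed through its t argument:

```python
import numpy as np
from scipy import signal

t = np.linspace(0, 5, 100)
x = np.sin(t)

# Passing t makes resample return the matching resampled time axis as well
x_resampled, t_resampled = signal.resample(x, 25, t=t)

print(x_resampled.shape, t_resampled.shape)
print(np.allclose(t_resampled, t[::4]))
```

Because the method is Fourier based, the signal is treated as periodic, so values near the two ends of a non-periodic signal like this one can deviate from the original curve.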

"""
Resample a signal with scipy.signal.resample
=============================================
:func:`scipy.signal.resample` uses FFT to resample a 1D signal.
"""

############################################################
# Generate a signal with 100 data point
import numpy as np
t = np.linspace(0, 5, 100)
x = np.sin(t)

############################################################
# Downsample it by a factor of 4
from scipy import signal
x_resampled = signal.resample(x, 25)

############################################################
# Plot
from matplotlib import pyplot as plt
plt.figure(figsize=(5, 4))
plt.plot(t, x, label='Original signal')
plt.plot(t[::4], x_resampled, 'ko', label='Resampled signal')

plt.legend(loc='best')
plt.show()

1.5.12.4 Integrating a simple ODE

Solve the ODE dy/dt = -2y between t = 0..4, with the initial condition y(t=0) = 1.

import numpy as np
from scipy.integrate import odeint
from matplotlib import pyplot as plt

def calc_derivative(ypos, time):
    return -2*ypos

time_vec = np.linspace(0, 4, 40)
yvec = odeint(calc_derivative, 1, time_vec)

plt.figure(figsize=(4, 3))
plt.plot(time_vec, yvec)
plt.xlabel('t: Time')
plt.ylabel('y: Position')
plt.tight_layout()
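
This ODE has the closed-form solution y(t) = exp(-2t), which gives a simple way to check the numerical result. The following sketch repeats the integration and measures the largest deviation from the exact curve.

```python
import numpy as np
from scipy.integrate import odeint

def calc_derivative(ypos, time):
    return -2 * ypos

time_vec = np.linspace(0, 4, 40)
yvec = odeint(calc_derivative, 1, time_vec)

# Exact solution of dy/dt = -2y with y(0) = 1
y_exact = np.exp(-2 * time_vec)

# odeint returns one column per state variable, hence the [:, 0]
max_error = np.max(np.abs(yvec[:, 0] - y_exact))
print(max_error)
```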


"""
=========================
Integrating a simple ODE
=========================
Solve the ODE dy/dt = -2y between t = 0..4, with the initial condition
y(t=0) = 1.
"""

import numpy as np
from scipy.integrate import odeint
from matplotlib import pyplot as plt

def calc_derivative(ypos, time):
    return -2*ypos

time_vec = np.linspace(0, 4, 40)
yvec = odeint(calc_derivative, 1, time_vec)

plt.figure(figsize=(4, 3))
plt.plot(time_vec, yvec)
plt.xlabel('t: Time')
plt.ylabel('y: Position')
plt.tight_layout()

1.5.12.5 Comparing 2 sets of samples from Gaussians


import numpy as np
from matplotlib import pyplot as plt

# Generates 2 sets of observations
samples1 = np.random.normal(0, size=1000)
samples2 = np.random.normal(1, size=1000)

# Compute a histogram of each sample
bins = np.linspace(-4, 4, 30)
# density=True replaces the normed=True argument removed from NumPy and Matplotlib
histogram1, bins = np.histogram(samples1, bins=bins, density=True)
histogram2, bins = np.histogram(samples2, bins=bins, density=True)

plt.figure(figsize=(6, 4))
plt.hist(samples1, bins=bins, density=True, label="Samples 1")
plt.hist(samples2, bins=bins, density=True, label="Samples 2")
plt.legend(loc='best')
plt.show()
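
The histograms compare the two samples visually. One possible numerical comparison, sketched here as an assumption about what "comparing" could mean for this example, is a two-sample t-test with scipy.stats.ttest_ind. The seeded generator is added for reproducibility.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)  # seeded here only for reproducibility
samples1 = rng.normal(0, size=1000)
samples2 = rng.normal(1, size=1000)

# Test the null hypothesis that both samples have the same mean
t_statistic, p_value = stats.ttest_ind(samples1, samples2)
print(t_statistic, p_value)
```

A very small p-value indicates that a difference in means this large would be unlikely if both samples came from the same distribution.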


"""
==========================================
Comparing 2 sets of samples from Gaussians
==========================================
"""

import numpy as np
from matplotlib import pyplot as plt

# Generates 2 sets of observations
samples1 = np.random.normal(0, size=1000)
samples2 = np.random.normal(1, size=1000)

# Compute a histogram of the sample
bins = np.linspace(-4, 4, 30)
histogram1, bins = np.histogram(samples1, bins=bins, normed=True)
histogram2, bins = np.histogram(samples2, bins=bins, normed=True)

plt.figure(figsize=(6, 4))
plt.hist(samples1, bins=bins, normed=True, label="Samples 1")
plt.hist(samples2, bins=bins, normed=True, label="Samples 2")
plt.legend(loc='best')
plt.show()


1.5.12.6 Integrate the Damped spring-mass oscillator


import numpy as np
from scipy.integrate import odeint
from matplotlib import pyplot as plt

mass = 0.5  # kg
kspring = 4  # N/m
cviscous = 0.4  # N s/m


eps = cviscous / (2 * mass * np.sqrt(kspring/mass))
omega = np.sqrt(kspring / mass)


def calc_deri(yvec, time, eps, omega):
    # y'' = -2*eps*omega*y' - omega**2*y, with eps the damping ratio defined above
    return (yvec[1], -2 * eps * omega * yvec[1] - omega**2 * yvec[0])

time_vec = np.linspace(0, 10, 100)
yinit = (1, 0)
yarr = odeint(calc_deri, yinit, time_vec, args=(eps, omega))

plt.figure(figsize=(4, 3))
plt.plot(time_vec, yarr[:, 0], label='y')
plt.plot(time_vec, yarr[:, 1], label="y'")
plt.legend(loc='best')
plt.show()
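
As a cross-check, the same oscillator can be integrated with scipy.integrate.solve_ivp. This sketch assumes the standard form y'' + 2*eps*omega*y' + omega**2*y = 0, for which the amplitude of an underdamped system (eps < 1) decays roughly like exp(-eps*omega*t). The function name rhs and the tolerances are choices made for this illustration.

```python
import numpy as np
from scipy.integrate import solve_ivp

mass = 0.5       # kg
kspring = 4      # N/m
cviscous = 0.4   # N s/m

omega = np.sqrt(kspring / mass)
eps = cviscous / (2 * mass * omega)

def rhs(time, yvec):
    # yvec[0] is the position, yvec[1] the velocity (solve_ivp passes time first)
    return [yvec[1], -2 * eps * omega * yvec[1] - omega**2 * yvec[0]]

time_vec = np.linspace(0, 10, 100)
sol = solve_ivp(rhs, (0, 10), [1, 0], t_eval=time_vec, rtol=1e-8, atol=1e-10)

# Approximate decay envelope of the oscillation amplitude
envelope = np.exp(-eps * omega * time_vec)
print(eps < 1, sol.y[0, -1], envelope[-1])
```

Note that solve_ivp expects the derivative function as f(t, y), the reverse of odeint's default argument order.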


"""
============================================
Integrate the Damped spring-mass oscillator
============================================
"""

import numpy as np
from scipy.integrate import odeint
from matplotlib import pyplot as plt

mass = 0.5  # kg
kspring = 4  # N/m
cviscous = 0.4  # N s/m


eps = cviscous / (2 * mass * np.sqrt(kspring/mass))
omega = np.sqrt(kspring / mass)


def calc_deri(yvec, time, eps, omega):
    return (yvec[1], -eps * omega * yvec[1] - omega **2 * yvec[0])

time_vec = np.linspace(0, 10, 100)
yinit = (1, 0)
yarr = odeint(calc_deri, yinit, time_vec, args=(eps, omega))

plt.figure(figsize=(4, 3))
plt.plot(time_vec, yarr[:, 0], label='y')
plt.plot(time_vec, yarr[:, 1], label="y'")
plt.legend(loc='best')
plt.show()


1.5.12.7 Normal distribution: histogram and PDF

Explore the normal distribution: a histogram built from samples and the PDF (probability density function).

import numpy as np

# Sample from a normal distribution using numpy's random number generator
samples = np.random.normal(size=10000)

# Compute a histogram of the sample
bins = np.linspace(-5, 5, 30)
histogram, bins = np.histogram(samples, bins=bins, density=True)

bin_centers = 0.5*(bins[1:] + bins[:-1])

# Compute the PDF on the bin centers from scipy distribution object
from scipy import stats
pdf = stats.norm.pdf(bin_centers)

from matplotlib import pyplot as plt
plt.figure(figsize=(6, 4))
plt.plot(bin_centers, histogram, label="Histogram of samples")
plt.plot(bin_centers, pdf, label="PDF")
plt.legend()
plt.show()
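
Beyond the PDF, the same scipy.stats.norm object can estimate parameters from samples and evaluate the CDF and its inverse. The following is a minimal sketch with a seeded generator added for reproducibility.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)  # seeded here only for reproducibility
samples = rng.normal(size=10000)

# Maximum-likelihood estimates of the mean (loc) and standard deviation (scale)
loc, scale = stats.norm.fit(samples)
print(loc, scale)

# The CDF and its inverse (ppf) live on the same distribution object
print(stats.norm.cdf(0.0), stats.norm.ppf(0.975))
```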


"""
=======================================
Normal distribution: histogram and PDF
=======================================
Explore the normal distribution: a histogram built from samples and the
PDF (probability density function).
"""

import numpy as np

# Sample from a normal distribution using numpy's random number generator
samples = np.random.normal(size=10000)

# Compute a histogram of the sample
bins = np.linspace(-5, 5, 30)
histogram, bins = np.histogram(samples, bins=bins, normed=True)

bin_centers = 0.5*(bins[1:] + bins[:-1])

# Compute the PDF on the bin centers from scipy distribution object
from scipy import stats
pdf = stats.norm.pdf(bin_centers)

from matplotlib import pyplot as plt
plt.figure(figsize=(6, 4))
plt.plot(bin_centers, histogram, label="Histogram of samples")
plt.plot(bin_centers, pdf, label="PDF")
plt.legend()
plt.show()

1.5.12.8 Curve fitting

Demos a simple curve fitting

First generate some data

import numpy as np

# Seed the random number generator for reproducibility
np.random.seed(0)

x_data = np.linspace(-5, 5, num=50)
y_data = 2.9 * np.sin(1.5 * x_data) + np.random.normal(size=50)

# And plot it
import matplotlib.pyplot as plt
plt.figure(figsize=(6, 4))
plt.scatter(x_data, y_data)

Now fit a simple sine function to the data

from scipy import optimize

def test_func(x, a, b):
    return a * np.sin(b * x)

params, params_covariance = optimize.curve_fit(test_func, x_data, y_data,
                                               p0=[2, 2])

print(params)

Out:

[ 3.05931973  1.45754553]

And plot the resulting curve on the data

plt.figure(figsize=(6, 4))
plt.scatter(x_data, y_data, label='Data')
plt.plot(x_data, test_func(x_data, params[0], params[1]),
         label='Fitted function')

plt.legend(loc='best')

plt.show()

"""
===============
Curve fitting
===============
Demos a simple curve fitting
"""

############################################################
# First generate some data
import numpy as np

# Seed the random number generator for reproducibility
np.random.seed(0)

x_data = np.linspace(-5, 5, num=50)
y_data = 2.9 * np.sin(1.5 * x_data) + np.random.normal(size=50)

# And plot it
import matplotlib.pyplot as plt
plt.figure(figsize=(6, 4))
plt.scatter(x_data, y_data)

############################################################
# Now fit a simple sine function to the data
from scipy import optimize

def test_func(x, a, b):
    return a * np.sin(b * x)

params, params_covariance = optimize.curve_fit(test_func, x_data, y_data,
                                               p0=[2, 2])

print(params)

############################################################
# And plot the resulting curve on the data

plt.figure(figsize=(6, 4))
plt.scatter(x_data, y_data, label='Data')
plt.plot(x_data, test_func(x_data, params[0], params[1]),
         label='Fitted function')

plt.legend(loc='best')

plt.show()


1.5.12.9 Spectrogram, power spectral density

Demo spectrogram and power spectral density on a frequency chirp.

import numpy as np
from matplotlib import pyplot as plt


Generate a chirp signal

# Seed the random number generator
np.random.seed(0)

time_step = .01
time_vec = np.arange(0, 70, time_step)

# A signal with a small frequency chirp
sig = np.sin(0.5 * np.pi * time_vec * (1 + .1 * time_vec))

plt.figure(figsize=(8, 5))
plt.plot(time_vec, sig)


Compute and plot the spectrogram


The spectrum of the signal on consecutive time windows

from scipy import signal
freqs, times, spectrogram = signal.spectrogram(sig)

plt.figure(figsize=(5, 4))
plt.imshow(spectrogram, aspect='auto', cmap='hot_r', origin='lower')
plt.title('Spectrogram')
plt.ylabel('Frequency band')
plt.xlabel('Time window')
plt.tight_layout()


Compute and plot the power spectral density (PSD)


The power of the signal per frequency band

freqs, psd = signal.welch(sig)

plt.figure(figsize=(5, 4))
plt.semilogx(freqs, psd)
plt.title('PSD: power spectral density')
plt.xlabel('Frequency')
plt.ylabel('Power')
plt.tight_layout()
plt.show()
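
The calls above do not pass a sampling frequency, so the returned frequencies are in cycles per sample, which matches the plots being labelled by band and window. As a sketch, passing fs (derived here from time_step) yields frequencies in Hz and times in seconds.

```python
import numpy as np
from scipy import signal

time_step = .01
time_vec = np.arange(0, 70, time_step)
sig = np.sin(0.5 * np.pi * time_vec * (1 + .1 * time_vec))

fs = 1 / time_step  # sampling frequency in Hz

freqs, times, spectrogram = signal.spectrogram(sig, fs=fs)
freqs_psd, psd = signal.welch(sig, fs=fs)

# Rows of the spectrogram follow freqs, columns follow times
print(spectrogram.shape == (freqs.size, times.size))
# The last frequency returned is the Nyquist frequency, fs / 2
print(freqs_psd[-1])
```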


"""
======================================
Spectrogram, power spectral density
======================================
Demo spectrogram and power spectral density on a frequency chirp.
"""

import numpy as np
from matplotlib import pyplot as plt

############################################################
# Generate a chirp signal
############################################################

# Seed the random number generator
np.random.seed(0)

time_step = .01
time_vec = np.arange(0, 70, time_step)

# A signal with a small frequency chirp
sig = np.sin(0.5 * np.pi * time_vec * (1 + .1 * time_vec))

plt.figure(figsize=(8, 5))
plt.plot(time_vec, sig)

############################################################
# Compute and plot the spectrogram
############################################################
#
# The spectrum of the signal on consecutive time windows

from scipy import signal
freqs, times, spectrogram = signal.spectrogram(sig)

plt.figure(figsize=(5, 4))
plt.imshow(spectrogram, aspect='auto', cmap='hot_r', origin='lower')
plt.title('Spectrogram')
plt.ylabel('Frequency band')
plt.xlabel('Time window')
plt.tight_layout()


############################################################
# Compute and plot the power spectral density (PSD)
############################################################
#
# The power of the signal per frequency band

freqs, psd = signal.welch(sig)

plt.figure(figsize=(5, 4))
plt.semilogx(freqs, psd)
plt.title('PSD: power spectral density')
plt.xlabel('Frequency')
plt.ylabel('Power')
plt.tight_layout()

############################################################

plt.show()

1.5.12.10 A demo of 1D interpolation


# Generate data
import numpy as np
np.random.seed(0)
measured_time = np.linspace(0, 1, 10)
noise = 1e-1 * (np.random.random(10)*2 - 1)
measures = np.sin(2 * np.pi * measured_time) + noise

# Interpolate it to new time points
from scipy.interpolate import interp1d
linear_interp = interp1d(measured_time, measures)
interpolation_time = np.linspace(0, 1, 50)
linear_results = linear_interp(interpolation_time)
cubic_interp = interp1d(measured_time, measures, kind='cubic')
cubic_results = cubic_interp(interpolation_time)

# Plot the data and the interpolation
from matplotlib import pyplot as plt
plt.figure(figsize=(6, 4))
plt.plot(measured_time, measures, 'o', ms=6, label='measures')
plt.plot(interpolation_time, linear_results, label='linear interp')
plt.plot(interpolation_time, cubic_results, label='cubic interp')
plt.legend()
plt.show()
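
The demo only evaluates the interpolants inside the measured range. The sketch below, using noise-free data for simplicity, shows what happens outside that range: by default interp1d raises a ValueError, while fill_value="extrapolate" continues the curve. The query point 1.2 is an arbitrary choice.

```python
import numpy as np
from scipy.interpolate import interp1d

measured_time = np.linspace(0, 1, 10)
measures = np.sin(2 * np.pi * measured_time)

linear_interp = interp1d(measured_time, measures)

# By default, evaluating outside [0, 1] raises a ValueError
try:
    linear_interp(1.2)
    raised = False
except ValueError:
    raised = True
print(raised)

# fill_value="extrapolate" extends the interpolant beyond the data instead
extrap_interp = interp1d(measured_time, measures, fill_value="extrapolate")
print(float(extrap_interp(1.2)))
```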


"""
============================
A demo of 1D interpolation
============================
"""

# Generate data
import numpy as np
np.random.seed(0)
measured_time = np.linspace(0, 1, 10)
noise = 1e-1 * (np.random.random(10)*2 - 1)
measures = np.sin(2 * np.pi * measured_time) + noise

# Interpolate it to new time points
from scipy.interpolate import interp1d
linear_interp = interp1d(measured_time, measures)
interpolation_time = np.linspace(0, 1, 50)
linear_results = linear_interp(interpolation_time)
cubic_interp = interp1d(measured_time, measures, kind='cubic')
cubic_results = cubic_interp(interpolation_time)

# Plot the data and the interpolation
from matplotlib import pyplot as plt
plt.figure(figsize=(6, 4))
plt.plot(measured_time, measures, 'o', ms=6, label='measures')
plt.plot(interpolation_time, linear_results, label='linear interp')
plt.plot(interpolation_time, cubic_results, label='cubic interp')
plt.legend()
plt.show()

1.5.12.11 Demo mathematical morphology

A basic demo of binary opening and closing.

# Generate some binary data
import numpy as np
np.random.seed(0)
a = np.zeros((50, 50))
a[10:-10, 10:-10] = 1
a += 0.25 * np.random.standard_normal(a.shape)
mask = a>=0.5

# Apply mathematical morphology
from scipy import ndimage
opened_mask = ndimage.binary_opening(mask)
closed_mask = ndimage.binary_closing(opened_mask)

# Plot
from matplotlib import pyplot as plt

plt.figure(figsize=(12, 3.5))
plt.subplot(141)
plt.imshow(a, cmap=plt.cm.gray)
plt.axis('off')
plt.title('a')

plt.subplot(142)
plt.imshow(mask, cmap=plt.cm.gray)
plt.axis('off')
plt.title('mask')

plt.subplot(143)
plt.imshow(opened_mask, cmap=plt.cm.gray)
plt.axis('off')
plt.title('opened_mask')

plt.subplot(144)
plt.imshow(closed_mask, cmap=plt.cm.gray)
plt.title('closed_mask')
plt.axis('off')

plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.99)

plt.show()
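
The effect of the two operations is easier to see on a tiny hand-made mask than on noisy data. In this sketch, with a layout invented for illustration, opening removes an isolated pixel and closing fills a one-pixel hole.

```python
import numpy as np
from scipy import ndimage

mask = np.zeros((9, 9), dtype=bool)
mask[2:7, 2:7] = True   # a 5x5 square
mask[4, 4] = False      # with a one-pixel hole in the middle
mask[0, 8] = True       # plus an isolated pixel of noise

# Opening (erosion then dilation) removes small bright specks
opened = ndimage.binary_opening(mask)
# Closing (dilation then erosion) fills small dark holes
closed = ndimage.binary_closing(mask)

print(opened[0, 8])   # the isolated pixel after opening
print(closed[4, 4])   # the hole after closing
```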


"""
=============================
Demo mathematical morphology
=============================
A basic demo of binary opening and closing.
"""

# Generate some binary data
import numpy as np
np.random.seed(0)
a = np.zeros((50, 50))
a[10:-10, 10:-10] = 1
a += 0.25 * np.random.standard_normal(a.shape)
mask = a>=0.5

# Apply mathematical morphology
from scipy import ndimage
opened_mask = ndimage.binary_opening(mask)
closed_mask = ndimage.binary_closing(opened_mask)

# Plot
from matplotlib import pyplot as plt

plt.figure(figsize=(12, 3.5))
plt.subplot(141)
plt.imshow(a, cmap=plt.cm.gray)
plt.axis('off')
plt.title('a')

plt.subplot(142)
plt.imshow(mask, cmap=plt.cm.gray)
plt.axis('off')
plt.title('mask')

plt.subplot(143)
plt.imshow(opened_mask, cmap=plt.cm.gray)
plt.axis('off')
plt.title('opened_mask')

plt.subplot(144)
plt.imshow(closed_mask, cmap=plt.cm.gray)
plt.title('closed_mask')
plt.axis('off')

plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.99)

plt.show()

1.5.12.12 Plot geometrical transformations on images

Demo geometrical transformations of images.

# Load some data
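# scipy.misc.face was removed in SciPy 1.12; scipy.datasets provides it (requires the pooch package)
from scipy import datasets
face = datasets.face(gray=True)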

# Apply a variety of transformations
from scipy import ndimage
from matplotlib import pyplot as plt
shifted_face = ndimage.shift(face, (50, 50))
shifted_face2 = ndimage.shift(face, (50, 50), mode='nearest')
rotated_face = ndimage.rotate(face, 30)
cropped_face = face[50:-50, 50:-50]
zoomed_face = ndimage.zoom(face, 2)
zoomed_face.shape

plt.figure(figsize=(15, 3))
plt.subplot(151)
plt.imshow(shifted_face, cmap=plt.cm.gray)
plt.axis('off')

plt.subplot(152)
plt.imshow(shifted_face2, cmap=plt.cm.gray)
plt.axis('off')

plt.subplot(153)
plt.imshow(rotated_face, cmap=plt.cm.gray)
plt.axis('off')

plt.subplot(154)
plt.imshow(cropped_face, cmap=plt.cm.gray)
plt.axis('off')

plt.subplot(155)
plt.imshow(zoomed_face, cmap=plt.cm.gray)
plt.axis('off')

plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.99)

plt.show()
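
The transformations differ in how they affect the array shape, which the figure does not make obvious. The sketch below applies them to a small synthetic array, used here instead of the face image so that nothing needs to be downloaded, and prints the resulting shapes.

```python
import numpy as np
from scipy import ndimage

image = np.arange(12, dtype=float).reshape(3, 4)

# shift keeps the shape and moves the content
shifted = ndimage.shift(image, (1, 1))
# rotate enlarges the output by default so the rotated image fits
rotated = ndimage.rotate(image, 90)
# reshape=False keeps the input shape and clips the corners instead
rotated_same = ndimage.rotate(image, 30, reshape=False)
# zoom scales every axis by the given factor
zoomed = ndimage.zoom(image, 2)

print(shifted.shape, rotated.shape, rotated_same.shape, zoomed.shape)
```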


"""
============================================
Plot geometrical transformations on images
============================================
Demo geometrical transformations of images.
"""

# Load some data
from scipy import misc
face = misc.face(gray=True)

# Apply a variety of transformations
from scipy import ndimage
from matplotlib import pyplot as plt
shifted_face = ndimage.shift(face, (50, 50))
shifted_face2 = ndimage.shift(face, (50, 50), mode='nearest')
rotated_face = ndimage.rotate(face, 30)
cropped_face = face[50:-50, 50:-50]
zoomed_face = ndimage.zoom(face, 2)
zoomed_face.shape

plt.figure(figsize=(15, 3))
plt.subplot(151)
plt.imshow(shifted_face, cmap=plt.cm.gray)
plt.axis('off')

plt.subplot(152)
plt.imshow(shifted_face2, cmap=plt.cm.gray)
plt.axis('off')

plt.subplot(153)
plt.imshow(rotated_face, cmap=plt.cm.gray)
plt.axis('off')

plt.subplot(154)
plt.imshow(cropped_face, cmap=plt.cm.gray)
plt.axis('off')

plt.subplot(155)
plt.imshow(zoomed_face, cmap=plt.cm.gray)
plt.axis('off')

plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.99)

plt.show()

1.5.12.13 Demo connected components

Extracting and labeling connected components in a 2D array

import numpy as np
from matplotlib import pyplot as plt

Generate some binary data

np.random.seed(0)
x, y = np.indices((100, 100))
sig = np.sin(2*np.pi*x/50.) * np.sin(2*np.pi*y/50.) * (1+x*y/50.**2)**2
mask = sig > 1

plt.figure(figsize=(7, 3.5))
plt.subplot(1, 2, 1)
plt.imshow(sig)
plt.axis('off')
plt.title('sig')

plt.subplot(1, 2, 2)
plt.imshow(mask, cmap=plt.cm.gray)
plt.axis('off')
plt.title('mask')
plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.9)

Label connected components

from scipy import ndimage
labels, nb = ndimage.label(mask)

plt.figure(figsize=(3.5, 3.5))
plt.imshow(labels)
plt.title('label')
plt.axis('off')

plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.9)

Extract the 4th connected component, and crop the array around it

sl = ndimage.find_objects(labels==4)
plt.figure(figsize=(3.5, 3.5))
plt.imshow(sig[sl[0]])
plt.title('Cropped connected component')
plt.axis('off')

plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.9)

plt.show()
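
The same steps can be followed on a mask small enough to read by eye. In this sketch the 5x5 mask is made up for illustration; find_objects is called on the whole label array, which returns one bounding box per component, and np.bincount counts the pixels of each label.

```python
import numpy as np
from scipy import ndimage

mask = np.array([[1, 1, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 0, 1, 1],
                 [0, 0, 0, 0, 1]], dtype=bool)

# Pixels are connected through their edges by default, not their corners
labels, nb = ndimage.label(mask)
print(nb)

# One tuple of slices (a bounding box) per labelled component
slices = ndimage.find_objects(labels)
# Pixel count per component; index 0 of bincount is the background
sizes = np.bincount(labels.ravel())[1:]

for index, (sl, size) in enumerate(zip(slices, sizes), start=1):
    print(index, sl, size)
```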

"""
=============================
Demo connected components
=============================
Extracting and labeling connected components in a 2D array
"""

import numpy as np
from matplotlib import pyplot as plt

############################################################
# Generate some binary data
np.random.seed(0)
x, y = np.indices((100, 100))
sig = np.sin(2*np.pi*x/50.) * np.sin(2*np.pi*y/50.) * (1+x*y/50.**2)**2
mask = sig > 1

plt.figure(figsize=(7, 3.5))
plt.subplot(1, 2, 1)
plt.imshow(sig)
plt.axis('off')
plt.title('sig')

plt.subplot(1, 2, 2)
plt.imshow(mask, cmap=plt.cm.gray)
plt.axis('off')
plt.title('mask')
plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.9)


############################################################
# Label connected components
from scipy import ndimage
labels, nb = ndimage.label(mask)

plt.figure(figsize=(3.5, 3.5))
plt.imshow(labels)
plt.title('label')
plt.axis('off')

plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.9)


############################################################
# Extract the 4th connected component, and crop the array around it
sl = ndimage.find_objects(labels==4)
plt.figure(figsize=(3.5, 3.5))
plt.imshow(sig[sl[0]])
plt.title('Cropped connected component')
plt.axis('off')

plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.9)

plt.show()


1.5.12.14 Minima and roots of a function

Demos finding minima and roots of a function.

import numpy as np

x = np.arange(-10, 10, 0.1)
def f(x):
    return x**2 + 10*np.sin(x)


Find minima

from scipy import optimize

# Global optimization
grid = (-10, 10, 0.1)
xmin_global = optimize.brute(f, (grid, ))
print("Global minima found %s" % xmin_global)

# Constrained optimization
xmin_local = optimize.fminbound(f, 0, 10)
print("Local minimum found %s" % xmin_local)

Out:

Global minima found [-1.30641113]
Local minimum found 3.8374671195


Root finding

root = optimize.root(f, 1)  # our initial guess is 1
print("First root found %s" % root.x)
root2 = optimize.root(f, -2.5)
print("Second root found %s" % root2.x)

Out:

First root found [ 0.]
Second root found [-2.47948183]


Plot function, minima, and roots

import matplotlib.pyplot as plt
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(111)

# Plot the function
ax.plot(x, f(x), 'b-', label="f(x)")

# Plot the minima
xmins = np.array([xmin_global[0], xmin_local])
ax.plot(xmins, f(xmins), 'go', label="Minima")

# Plot the roots
roots = np.array([root.x, root2.x])
ax.plot(roots, f(roots), 'kv', label="Roots")

# Decorate the figure
ax.legend(loc='best')
ax.set_xlabel('x')
ax.set_ylabel('f(x)')
ax.axhline(0, color='gray')
plt.show()
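
optimize.root starts from an initial guess. When an interval on which the function changes sign is known, a bracketing method is an alternative. The sketch below uses optimize.brentq on the interval [-3, -2], chosen here because f is positive at -3 and negative at -2.

```python
import numpy as np
from scipy import optimize

def f(x):
    return x**2 + 10*np.sin(x)

# f changes sign on [-3, -2], so a root is guaranteed inside the bracket
print(f(-3.0), f(-2.0))

root_bracketed = optimize.brentq(f, -3, -2)
print(root_bracketed)
```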

"""
===============================
Minima and roots of a function
===============================
Demos finding minima and roots of a function.
"""

############################################################
# Define the function
############################################################

import numpy as np

x = np.arange(-10, 10, 0.1)
def f(x):
    return x**2 + 10*np.sin(x)


############################################################
# Find minima
############################################################

from scipy import optimize

# Global optimization
grid = (-10, 10, 0.1)
xmin_global = optimize.brute(f, (grid, ))
print("Global minima found %s" % xmin_global)

# Constrain optimization
xmin_local = optimize.fminbound(f, 0, 10)
print("Local minimum found %s" % xmin_local)

############################################################
# Root finding
############################################################

root = optimize.root(f, 1)  # our initial guess is 1
print("First root found %s" % root.x)
root2 = optimize.root(f, -2.5)
print("Second root found %s" % root2.x)

############################################################
# Plot function, minima, and roots
############################################################

import matplotlib.pyplot as plt
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(111)

# Plot the function
ax.plot(x, f(x), 'b-', label="f(x)")

# Plot the minima
xmins = np.array([xmin_global[0], xmin_local])
ax.plot(xmins, f(xmins), 'go', label="Minima")

# Plot the roots
roots = np.array([root.x, root2.x])
ax.plot(roots, f(roots), 'kv', label="Roots")

# Decorate the figure
ax.legend(loc='best')
ax.set_xlabel('x')
ax.set_ylabel('f(x)')
ax.axhline(0, color='gray')
plt.show()

1.5.12.15 Optimization of a two-parameter function

import numpy as np


# Define the function that we are interested in
def sixhump(x):
    return ((4 - 2.1*x[0]**2 + x[0]**4 / 3.) * x[0]**2 + x[0] * x[1]
            + (-4 + 4*x[1]**2) * x[1] **2)

# Make a grid to evaluate the function (for plotting)
x = np.linspace(-2, 2)
y = np.linspace(-1, 1)
xg, yg = np.meshgrid(x, y)


A 2D image plot of the function


Simple visualization in 2D

import matplotlib.pyplot as plt
plt.figure()
plt.imshow(sixhump([xg, yg]), extent=[-2, 2, -1, 1])
plt.colorbar()


A 3D surface plot of the function

from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(xg, yg, sixhump([xg, yg]), rstride=1, cstride=1,
                       cmap=plt.cm.jet, linewidth=0, antialiased=False)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('f(x, y)')
ax.set_title('Six-hump Camelback function')

Find the minima


from scipy import optimize

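# (0, 0) is a saddle point of this function: the gradient vanishes there,
# so a gradient-based solver started at the origin stops immediately.
# Start away from it instead.
x_min = optimize.minimize(sixhump, x0=[0.5, -0.5])

plt.figure()
# Show the function in 2D
plt.imshow(sixhump([xg, yg]), extent=[-2, 2, -1, 1], origin="lower")
plt.colorbar()
# And the minimum that we've found:
plt.scatter(x_min.x[0], x_min.x[1])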

plt.show()
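
A single call to optimize.minimize returns only one stationary point, and which one depends on the starting point. The function has several local minima, and the origin is a saddle point where the gradient vanishes. A minimal multi-start sketch, assuming a hypothetical 5 x 5 grid of starting points over the plotted region, runs the solver from each point and keeps the lowest value found:

```python
import numpy as np
from scipy import optimize


def sixhump(x):
    return ((4 - 2.1*x[0]**2 + x[0]**4 / 3.) * x[0]**2 + x[0] * x[1]
            + (-4 + 4*x[1]**2) * x[1]**2)


# Hypothetical grid of starting points covering [-2, 2] x [-1, 1]
starts = [(x0, y0)
          for x0 in np.linspace(-2, 2, 5)
          for y0 in np.linspace(-1, 1, 5)]

# Run the default solver from every start and keep the best result
results = [optimize.minimize(sixhump, x0=s) for s in starts]
best = min(results, key=lambda res: res.fun)

print("Lowest value found: %s at %s" % (best.fun, best.x))
```

This is a brute-force strategy that only works for cheap functions in low dimension; it is shown here to illustrate the dependence on x0, not as a recommended global optimizer.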

1.5.12.16 Plot filtering on images

Demo filtering for denoising of images.

# Load some data
# scipy.misc.face was removed in SciPy 1.12; scipy.datasets.face
# (SciPy >= 1.10, needs the optional pooch package) replaces it
from scipy import datasets
face = datasets.face(gray=True)
face = face[:512, -512:]  # crop out square on right

# Apply a variety of filters
from scipy import ndimage
from scipy import signal
from matplotlib import pyplot as plt

import numpy as np
# np.float was removed in NumPy 1.24; the builtin float is equivalent
noisy_face = np.copy(face).astype(float)
noisy_face += face.std() * 0.5 * np.random.standard_normal(face.shape)
blurred_face = ndimage.gaussian_filter(noisy_face, sigma=3)
median_face = ndimage.median_filter(noisy_face, size=5)
wiener_face = signal.wiener(noisy_face, (5, 5))

plt.figure(figsize=(12, 3.5))
plt.subplot(141)
plt.imshow(noisy_face, cmap=plt.cm.gray)
plt.axis('off')
plt.title('noisy')

plt.subplot(142)
plt.imshow(blurred_face, cmap=plt.cm.gray)
plt.axis('off')
plt.title('Gaussian filter')

plt.subplot(143)
plt.imshow(median_face, cmap=plt.cm.gray)
plt.axis('off')
plt.title('median filter')

plt.subplot(144)
plt.imshow(wiener_face, cmap=plt.cm.gray)
plt.title('Wiener filter')
plt.axis('off')

plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.99)

plt.show()
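
The comparison above is visual only. To put numbers on it, the sketch below applies the same three filters to a synthetic image whose noise-free version is known, and compares the mean squared error before and after filtering. The test image (a bright square), the noise level and the seed are assumptions made for illustration.

```python
import numpy as np
from scipy import ndimage, signal

rng = np.random.default_rng(0)

# Synthetic test image (an assumption): a bright square on a dark background
clean = np.zeros((128, 128))
clean[32:96, 32:96] = 1.0
noisy = clean + 0.3 * rng.standard_normal(clean.shape)

# Same filters and parameters as in the example above
filtered = {
    'Gaussian filter': ndimage.gaussian_filter(noisy, sigma=3),
    'median filter': ndimage.median_filter(noisy, size=5),
    'Wiener filter': signal.wiener(noisy, (5, 5)),
}


def mse(a, b):
    """Mean squared error between two images."""
    return np.mean((a - b) ** 2)


mse_noisy = mse(noisy, clean)
mse_filtered = {name: mse(img, clean) for name, img in filtered.items()}

print("noisy: %.4f" % mse_noisy)
for name, value in mse_filtered.items():
    print("%s: %.4f" % (name, value))
```

Which filter scores best depends on the image content and the noise level, so the ranking on this toy image should not be generalized.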


1.5.12.17 Plotting and manipulating FFTs for filtering

Plot the power of the FFT of a signal and inverse FFT back to reconstruct a signal.

This example demonstrates scipy.fftpack.fft(), scipy.fftpack.fftfreq() and scipy.fftpack.ifft(). It implements a basic filter that is very suboptimal and should not be used.

import numpy as np
from scipy import fftpack
from matplotlib import pyplot as plt



Generate the signal

# Seed the random number generator
np.random.seed(1234)

time_step = 0.02
period = 5.

time_vec = np.arange(0, 20, time_step)
sig = (np.sin(2 * np.pi / period * time_vec)
       + 0.5 * np.random.randn(time_vec.size))

plt.figure(figsize=(6, 5))
plt.plot(time_vec, sig, label='Original signal')


Compute and plot the power

# The FFT of the signal
sig_fft = fftpack.fft(sig)

# And the power: the squared magnitude (sig_fft is of complex dtype)
power = np.abs(sig_fft)**2

# The corresponding frequencies
sample_freq = fftpack.fftfreq(sig.size, d=time_step)

# Plot the FFT power
plt.figure(figsize=(6, 5))
plt.plot(sample_freq, power)
plt.xlabel('Frequency [Hz]')
plt.ylabel('power')

# Find the peak frequency: we can focus on only the positive frequencies
pos_mask = np.where(sample_freq > 0)
freqs = sample_freq[pos_mask]
peak_freq = freqs[power[pos_mask].argmax()]

# Check that it does indeed correspond to the frequency that we generate
# the signal with
np.allclose(peak_freq, 1./period)

# An inner plot to show the peak frequency
axes = plt.axes([0.55, 0.3, 0.3, 0.5])
plt.title('Peak frequency')
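# Index power with the same mask as freqs so both start at the first
# positive frequency
plt.plot(freqs[:8], power[pos_mask][:8])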
plt.setp(axes, yticks=[])

# scipy.signal.find_peaks_cwt can also be used for more advanced
# peak detection


Remove all the high frequencies


We now remove all the high frequencies and transform back from frequencies to signal.
high_freq_fft = sig_fft.copy()
high_freq_fft[np.abs(sample_freq) > peak_freq] = 0
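# The mask is symmetric in frequency, so the inverse FFT is real up to
# round-off; keep only the real part for plotting
filtered_sig = fftpack.ifft(high_freq_fft).real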

plt.figure(figsize=(6, 5))
plt.plot(time_vec, sig, label='Original signal')
plt.plot(time_vec, filtered_sig, linewidth=3, label='Filtered signal')
plt.xlabel('Time [s]')
plt.ylabel('Amplitude')

plt.legend(loc='best')

Note: This is actually a bad way of creating a filter: such a brutal cut-off in frequency space does not control distortion of the signal.

Filters should be created using the SciPy filter design code in scipy.signal.

plt.show()
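
As a sketch of what the filter design route can look like, the code below rebuilds the same noisy sine and low-pass filters it with a Butterworth filter from scipy.signal. The filter order (4) and the cut-off (0.5 Hz) are illustrative assumptions, not values from the example above. signal.sosfiltfilt runs the filter forward and backward, so the output has no phase shift.

```python
import numpy as np
from scipy import signal

# Same signal as in the example above
np.random.seed(1234)
time_step = 0.02
period = 5.
time_vec = np.arange(0, 20, time_step)
clean = np.sin(2 * np.pi / period * time_vec)
sig = clean + 0.5 * np.random.randn(time_vec.size)

# Design a 4th-order Butterworth low-pass filter (order and cut-off are
# illustrative choices); the sampling frequency is 1 / time_step = 50 Hz
fs = 1 / time_step
cutoff = 0.5  # Hz, above the 0.2 Hz signal frequency
sos = signal.butter(4, cutoff, btype='low', fs=fs, output='sos')

# Zero-phase filtering: forward and backward pass
filtered_sig = signal.sosfiltfilt(sos, sig)

# Compare both signals with the noise-free sine
err_noisy = np.mean((sig - clean) ** 2)
err_filtered = np.mean((filtered_sig - clean) ** 2)
print("MSE noisy: %.4f, MSE filtered: %.4f" % (err_noisy, err_filtered))
```

The second-order-sections form (output='sos') is used because it is numerically more robust than the transfer-function form for low cut-off frequencies.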


1.5.12.18 Solutions of the exercises for scipy

Crude periodicity finding

Discover the periods in evolution of animal populations (../../data/populations.txt)


Load the data

import numpy as np
data = np.loadtxt('../../../../data/populations.txt')
years = data[:, 0]
populations = data[:, 1:]


Plot the data

import matplotlib.pyplot as plt
plt.figure()
plt.plot(years, populations * 1e-3)
plt.xlabel('Year')
# Raw string: "\c" is not a valid escape sequence in a normal string
plt.ylabel(r'Population number ($\cdot10^3$)')
plt.legend(['hare', 'lynx', 'carrot'], loc=1)



Plot its periods

from scipy import fftpack

ft_populations = fftpack.fft(populations, axis=0)
frequencies = fftpack.fftfreq(populations.shape[0], years[1] - years[0])
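# The zero-frequency term has an infinite period; silence the warning
with np.errstate(divide='ignore'):
    periods = 1 / frequencies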

plt.figure()
plt.plot(periods, abs(ft_populations) * 1e-3, 'o')
plt.xlim(0, 22)
plt.xlabel('Period')
plt.ylabel(r'Power ($\cdot10^3$)')

plt.show()

There’s probably a period of around 10 years (obvious from the plot), but for this crude a method, there’s not enough data to say much more.
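
Reading the period off the plot can also be done numerically by taking the positive frequency with the largest FFT amplitude. Since the data file is not available here, the sketch below uses a synthetic population with a known period; the series, its amplitude and its period of 10.5 years (chosen so that it falls exactly on an FFT bin for 21 yearly samples) are assumptions.

```python
import numpy as np
from scipy import fftpack

# Synthetic stand-in for one population column (an assumption)
years = np.arange(1900, 1921)          # 21 yearly samples
true_period = 10.5                     # years
population = 30000 + 20000 * np.sin(2 * np.pi * years / true_period)

ft = fftpack.fft(population)
frequencies = fftpack.fftfreq(population.size, years[1] - years[0])

# Ignore the zero-frequency (mean) term and the negative frequencies
positive = frequencies > 0
peak = np.abs(ft[positive]).argmax()
dominant_period = 1 / frequencies[positive][peak]

print("Dominant period: %.2f years" % dominant_period)
```

With only 21 samples the frequency resolution is 1/21 per year, so on real data the estimate can only land on periods 21, 10.5, 7, ... years.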


Curve fitting: temperature as a function of month of the year

We have the min and max temperatures in Alaska for each month of the year. We would like to find a function to describe this yearly evolution.

For this, we will fit a periodic function.



The data

import numpy as np

temp_max = np.array([17,  19,  21,  28,  33,  38, 37,  37,  31,  23,  19,  18])
temp_min = np.array([-62, -59, -56, -46, -32, -18, -9, -13, -25, -46, -52, -58])

import matplotlib.pyplot as plt
months = np.arange(12)
plt.plot(months, temp_max, 'ro')
plt.plot(months, temp_min, 'bo')
plt.xlabel('Month')
plt.ylabel('Min and max temperature')


Fitting it to a periodic function

from scipy import optimize
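def yearly_temps(times, avg, ampl, time_offset):
    # The period is fixed at 12 months; times.max() would make it depend
    # on the array passed in (11 for the data, 12 for the plotted curve)
    return (avg
            + ampl * np.cos((times + time_offset) * 2 * np.pi / 12))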

res_max, cov_max = optimize.curve_fit(yearly_temps, months,
                                      temp_max, [20, 10, 0])
res_min, cov_min = optimize.curve_fit(yearly_temps, months,
                                      temp_min, [-40, 20, 0])
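
curve_fit also returns the covariance matrix of the fitted parameters, which the example stores in cov_max and cov_min without using it. A minimal sketch of how one-standard-deviation uncertainties can be read from it, shown for the maximum temperatures only and assuming a model with the period fixed at 12 months:

```python
import numpy as np
from scipy import optimize

temp_max = np.array([17, 19, 21, 28, 33, 38, 37, 37, 31, 23, 19, 18])
months = np.arange(12)


def yearly_temps(times, avg, ampl, time_offset):
    # Period fixed at 12 months (an assumption of this sketch)
    return avg + ampl * np.cos((times + time_offset) * 2 * np.pi / 12)


res_max, cov_max = optimize.curve_fit(yearly_temps, months,
                                      temp_max, [20, 10, 0])

# The diagonal of the covariance matrix holds the parameter variances
err_max = np.sqrt(np.diag(cov_max))

for name, value, err in zip(['avg', 'ampl', 'time_offset'], res_max, err_max):
    print("%s = %.2f +/- %.2f" % (name, value, err))
```

These error bars assume independent, equally uncertain data points; they say nothing about whether a single cosine is the right model.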



Plotting the fit

days = np.linspace(0, 12, num=365)

plt.figure()
plt.plot(months, temp_max, 'ro')
plt.plot(days, yearly_temps(days, *res_max), 'r-')
plt.plot(months, temp_min, 'bo')
plt.plot(days, yearly_temps(days, *res_min), 'b-')
plt.xlabel('Month')
plt.ylabel(r'Temperature ($^\circ$C)')

plt.show()


Simple image blur by convolution with a Gaussian kernel

Blur an image (../../../../data/elephant.png) using a Gaussian kernel.

Convolution is easy to perform with FFT: convolving two signals boils down to multiplying their FFTs (and performing an inverse FFT).

import numpy as np
from scipy import fftpack
import matplotlib.pyplot as plt



The original image

# read image
img = plt.imread('../../../../data/elephant.png')
plt.figure()
plt.imshow(img)



Prepare a Gaussian convolution kernel

# First a 1-D  Gaussian
t = np.linspace(-10, 10, 30)
bump = np.exp(-0.1*t**2)
bump /= np.trapz(bump) # normalize the integral to 1

# make a 2-D kernel out of it
kernel = bump[:, np.newaxis] * bump[np.newaxis, :]



Implement convolution via FFT

# Padded Fourier transform, with the same shape as the image
# We use scipy.fftpack.fft2 to have a 2D FFT
kernel_ft = fftpack.fft2(kernel, shape=img.shape[:2], axes=(0, 1))

# convolve
img_ft = fftpack.fft2(img, axes=(0, 1))
# the 'newaxis' is to match to color direction
img2_ft = kernel_ft[:, :, np.newaxis] * img_ft
img2 = fftpack.ifft2(img2_ft, axes=(0, 1)).real

# clip values to range
img2 = np.clip(img2, 0, 1)

# plot output
plt.figure()
plt.imshow(img2)

Further exercise (only if you are familiar with this stuff):

A “wrapped border” appears in the upper left and top edges of the image. This is because the padding is not done correctly, and does not take the kernel size into account (so the convolution “flows out of bounds of the image”). Try to remove this artifact.
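
One possible way to approach this exercise is sketched below: pad both the image and the kernel to the size of the full linear convolution (image size plus kernel size minus one along each axis), multiply the FFTs, and crop the central part. The random grayscale test image and the 15-point kernel are assumptions made to keep the sketch self-contained; the result is compared with scipy.signal.fftconvolve.

```python
import numpy as np
from scipy import fftpack, signal

rng = np.random.default_rng(0)
img = rng.random((64, 48))          # synthetic grayscale image (assumption)

# A small normalized Gaussian kernel with an odd number of points
t = np.linspace(-10, 10, 15)
bump = np.exp(-0.1 * t**2)
bump /= bump.sum()
kernel = bump[:, np.newaxis] * bump[np.newaxis, :]

# Pad to the size of the full linear convolution so nothing wraps around
full_shape = (img.shape[0] + kernel.shape[0] - 1,
              img.shape[1] + kernel.shape[1] - 1)
img_ft = fftpack.fft2(img, shape=full_shape)
kernel_ft = fftpack.fft2(kernel, shape=full_shape)
full = fftpack.ifft2(img_ft * kernel_ft).real

# Crop the central part so the output has the shape of the input image
r0 = (kernel.shape[0] - 1) // 2
c0 = (kernel.shape[1] - 1) // 2
blurred = full[r0:r0 + img.shape[0], c0:c0 + img.shape[1]]

# Reference result
reference = signal.fftconvolve(img, kernel, mode='same')
print(np.allclose(blurred, reference))
```

The wrap-around is gone, but the image is still implicitly surrounded by zeros, so the result darkens towards the border.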



A function to do it: scipy.signal.fftconvolve()

The above exercise was only for didactic reasons: there exists a function in scipy that will do this for us, and probably do a better job: scipy.signal.fftconvolve()
from scipy import signal
# mode='same' is there to enforce the same output shape as input arrays
# (ie avoid border effects)
img3 = signal.fftconvolve(img, kernel[:, :, np.newaxis], mode='same')
plt.figure()
plt.imshow(img3)

Note that we still have a decay to zero at the border of the image. Using scipy.ndimage.gaussian_filter() would get rid of this artifact.

plt.show()
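
The border behaviour is easy to see on a constant image, where any deviation from the constant comes from the border handling. In the sketch below the constant test image is an assumption, and sigma=3 is an illustrative value that is not matched exactly to the kernel. fftconvolve treats everything outside the image as zero, while gaussian_filter reflects the image at its edges by default.

```python
import numpy as np
from scipy import ndimage, signal

img = np.ones((64, 64))             # constant test image (assumption)

# Normalized Gaussian kernel: the weights sum to 1
t = np.linspace(-10, 10, 30)
bump = np.exp(-0.1 * t**2)
bump /= bump.sum()
kernel = bump[:, np.newaxis] * bump[np.newaxis, :]

by_fft = signal.fftconvolve(img, kernel, mode='same')
by_ndimage = ndimage.gaussian_filter(img, sigma=3)

print("fftconvolve      corner: %.3f  center: %.3f"
      % (by_fft[0, 0], by_fft[32, 32]))
print("gaussian_filter  corner: %.3f  center: %.3f"
      % (by_ndimage[0, 0], by_ndimage[32, 32]))
```

The mode argument of gaussian_filter selects other border treatments, such as 'constant', 'nearest' or 'wrap'.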



Image denoising by FFT

Denoise an image (../../../../data/moonlanding.png) by implementing a blur with an FFT.

Implements, via FFT, the following convolution:

f_1(t) = \int dt'\, K(t-t') f_0(t')

\tilde{f}_1(\omega) = \tilde{K}(\omega) \tilde{f}_0(\omega)
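
For sampled signals the same relation holds for the discrete Fourier transform, with the convolution taken as periodic (circular). A short numerical check in 1D, using an arbitrary random signal and a Gaussian-shaped kernel that are assumptions for illustration:

```python
import numpy as np
from scipy import fftpack

rng = np.random.default_rng(0)
n = 64
f0 = rng.standard_normal(n)                             # arbitrary signal
K = np.exp(-0.5 * (np.arange(n) - n // 2) ** 2 / 4.0)   # arbitrary kernel

# Left-hand side: circular convolution computed directly from its definition
direct = np.array([sum(K[(t - tp) % n] * f0[tp] for tp in range(n))
                   for t in range(n)])

# Right-hand side: product of the transforms, then inverse transform
via_fft = fftpack.ifft(fftpack.fft(K) * fftpack.fft(f0)).real

print(np.allclose(direct, via_fft))
```

The direct sum costs on the order of n**2 operations, the FFT route on the order of n log n, which is why the FFT is preferred for large kernels.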



Read and plot the image

import numpy as np
import matplotlib.pyplot as plt

im = plt.imread('../../../../data/moonlanding.png').astype(float)

plt.figure()
plt.imshow(im, plt.cm.gray)
plt.title('Original image')


Compute the 2d FFT of the input image

from scipy import fftpack
im_fft = fftpack.fft2(im)

# Show the results

def plot_spectrum(im_fft):
    from matplotlib.colors import LogNorm
    # A logarithmic colormap
    plt.imshow(np.abs(im_fft), norm=LogNorm(vmin=5))
    plt.colorbar()

plt.figure()
plot_spectrum(im_fft)
plt.title('Fourier transform')



Filter in FFT

# In the lines following, we'll make a copy of the original spectrum and
# truncate coefficients.

# Define the fraction of coefficients (in each direction) we keep
keep_fraction = 0.1

# Make im_fft2 a copy of the original transform. NumPy arrays have a copy
# method for this purpose.
im_fft2 = im_fft.copy()

# Set r and c to be the number of rows and columns of the array.
r, c = im_fft2.shape

# Set to zero all rows with indices between r*keep_fraction and
# r*(1-keep_fraction):
im_fft2[int(r*keep_fraction):int(r*(1-keep_fraction))] = 0

# Similarly with the columns:
im_fft2[:, int(c*keep_fraction):int(c*(1-keep_fraction))] = 0

plt.figure()
plot_spectrum(im_fft2)
plt.title('Filtered Spectrum')
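
In an unshifted FFT array the low frequencies sit in the four corners, so the two slice assignments above keep four corner blocks and zero everything else. The sketch below counts how many coefficients survive, using an array of ones as a stand-in for im_fft (the array and its shape are assumptions).

```python
import numpy as np

keep_fraction = 0.1

# Stand-in for the spectrum: every coefficient is nonzero (assumption)
spectrum = np.ones((100, 200), dtype=complex)
r, c = spectrum.shape

# Same truncation as above
spectrum[int(r*keep_fraction):int(r*(1-keep_fraction))] = 0
spectrum[:, int(c*keep_fraction):int(c*(1-keep_fraction))] = 0

# Fraction of coefficients that are still nonzero
kept = np.count_nonzero(spectrum) / spectrum.size
print("Fraction of coefficients kept: %.3f" % kept)
```

A fraction keep_fraction is kept at each end of each axis, so about (2 * keep_fraction)**2 of the coefficients remain, which is 4% for keep_fraction = 0.1.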


Reconstruct the final image

# Reconstruct the denoised image from the filtered spectrum, keep only the
# real part for display.
im_new = fftpack.ifft2(im_fft2).real

plt.figure()
plt.imshow(im_new, plt.cm.gray)
plt.title('Reconstructed Image')


Easier and better: scipy.ndimage.gaussian_filter()


Implementing filtering directly with FFTs is tricky and time consuming. We can use the Gaussian filter from scipy.ndimage
from scipy import ndimage
im_blur = ndimage.gaussian_filter(im, 4)

plt.figure()
plt.imshow(im_blur, plt.cm.gray)
plt.title('Blurred image')

plt.show()




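The Python SciPy library is a powerful tool for scientific computing and engineering applications. It is built on top of NumPy and provides advanced mathematical functions, statistical methods, signal processing, image processing and optimization algorithms. SciPy is well suited to data analysis, modeling and research.

### Main uses

1. **Scientific computing and mathematical modeling** SciPy provides linear algebra, integration, interpolation, Fourier transforms and other mathematical functionality, suited to fields such as physical simulation and numerical analysis[^1].
2. **Signal processing** The `scipy.signal` module can filter, convolve and spectrally analyze signals, and is widely used in communication systems and audio processing[^4].
3. **Solving optimization problems** SciPy includes several optimization algorithms for unconstrained and constrained problems, suited to tasks such as tuning machine learning model parameters and engineering design optimization[^5].
4. **Statistical analysis** It provides a rich set of statistical distributions and hypothesis tests, suited to data analysis and probabilistic modeling.
5. **Image processing** `scipy.ndimage` can filter, rotate and zoom images.

### Basic usage

#### Installation and import

SciPy requires Python and NumPy. It can usually be installed with pip, which installs NumPy automatically as a dependency:

```bash
pip install scipy
```

In a Python script, SciPy is imported as follows:

```python
import scipy
```

If only one submodule is needed, for example the signal processing module, it can be imported on its own:

```python
from scipy import signal
```

#### Example: low-pass filtering with SciPy

The following is a simple low-pass filter example. It uses `scipy.signal.butter` to create a Butterworth filter in second-order-sections form and `scipy.signal.sosfiltfilt` to apply it to the signal with zero phase shift:

```python
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt

# Generate a noisy sine signal
fs = 1000  # sampling rate
t = np.linspace(0, 1, fs, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.randn(t.size)

# Design the low-pass filter
sos = signal.butter(4, 10, btype='low', fs=fs, output='sos')

# Apply the filter forward and backward (zero phase)
y = signal.sosfiltfilt(sos, x)

# Plot the original and the filtered signal
plt.figure()
plt.plot(t, x, label='Noisy signal')
plt.plot(t, y, 'g', linewidth=2, label='Filtered signal')
plt.legend()
plt.show()
```

This example shows how SciPy can be used to process real-world signal data, removing high-frequency noise while keeping the useful information[^4].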