numpy矩阵乘法的解惑

#源码如下: 小批量梯度下降法 (mini-batch gradient descent)

import numpy as np
# Setting a random seed, feel free to change it and see different solutions.
np.random.seed(42)


# TODO: Fill in code in the function below to implement a gradient descent
# step for linear regression, following a squared error rule. See the docstring
# for parameters and returned variables.
def MSEStep(X, y, W, b, learn_rate = 0.005):
    """
    Perform one gradient descent step for linear regression, using the
    squared error as the performance metric.

    Parameters
    X : array of predictor features, shape (n_samples, n_features)
    y : array of outcome values, shape (n_samples,)
    W : predictor feature coefficients, shape (n_features,)
    b : regression function intercept (scalar)
    learn_rate : learning rate

    Returns
    W_new : predictor feature coefficients following gradient descent step
    b_new : intercept following gradient descent step
    """
    # Predictions: np.matmul promotes the 1-D W to a column vector, so the
    # product has shape (n_samples,); adding scalar b broadcasts elementwise.
    y_pred = np.matmul(X, W) + b

    # Residuals; positive where the model under-predicts.
    error = y - y_pred

    # np.matmul(error, X) promotes the 1-D error to a row vector, giving
    # X^T @ error with shape (n_features,). Adding (not subtracting) this
    # term descends the MSE gradient because error = y - y_pred.
    W_new = W + learn_rate * np.matmul(error, X)
    b_new = b + learn_rate * error.sum()
    return W_new, b_new


# The parts of the script below will be run when you press the "Test Run"
# button. The gradient descent step will be performed multiple times on
# the provided dataset, and the returned list of regression coefficients
# will be plotted.
def miniBatchGD(X, y, batch_size = 20, learn_rate = 0.005, num_iter = 25):
    """
    Perform mini-batch gradient descent on a given dataset.

    Parameters
    X : array of predictor features, shape (n_points, n_features)
    y : array of outcome values, shape (n_points,)
    batch_size : how many data points will be sampled for each iteration
    learn_rate : learning rate
    num_iter : number of batches used

    Returns
    regression_coef : list of arrays, each np.hstack((W, b)); the first
      entry is the initial (all-zero) parameters, followed by one entry
      per iteration
    """
    n_points = X.shape[0]
    W = np.zeros(X.shape[1])  # coefficients, one per feature
    b = 0                     # intercept

    # Record the starting parameters before any update is applied.
    regression_coef = [np.hstack((W, b))]
    for _ in range(num_iter):
        # Sample batch_size row indices with replacement.
        batch = np.random.choice(range(n_points), batch_size)
        X_batch = X[batch, :]
        y_batch = y[batch]
        W, b = MSEStep(X_batch, y_batch, W, b, learn_rate)
        regression_coef.append(np.hstack((W, b)))

    return regression_coef


if __name__ == "__main__":
    # perform gradient descent
    # data.csv layout: the last column is the outcome y, all preceding
    # columns are predictor features
    data = np.loadtxt('data.csv', delimiter = ',')
    X = data[:,:-1]
    y = data[:,-1]
   
    regression_coef = miniBatchGD(X, y)
    
    # plot the results
    import matplotlib.pyplot as plt
    
    plt.figure()
    X_min = X.min()
    X_max = X.max()
    counter = len(regression_coef)
    # Draw one fitted line per recorded (W, b) pair. counter counts down,
    # so early iterations get colors near white (1 - 0.92**big ~ 1) and the
    # final fit is drawn black (counter == 0 -> [0, 0, 0]).
    # NOTE(review): the 2-way unpack `W, b` assumes a single predictor
    # column (each hstack((W, b)) has length 2) — confirm data.csv shape.
    for W, b in regression_coef:
        counter -= 1
        color = [1 - 0.92 ** counter for _ in range(3)]
        plt.plot([X_min, X_max],[X_min * W + b, X_max * W + b], color = color)
    plt.scatter(X, y, zorder = 3)
    plt.show()

 

#$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$

#有几点疑惑澄清:

1 矩阵相乘(np.matmul)遇到shape为一维数组(即(n,)形式)时,是按参数位置提升维度的: 作为第一个参数时在前面补一维, 当作(1,n)的行向量处理; 作为第二个参数时在后面补一维, 当作(n,1)的列向量处理; 计算完成后结果中会去掉这补上的一维。若想显式得到shape=(1,n), 只需在一维数组最外层再加一层[]即可。

numpy.zeros(n) 默认返回的是shape=(n,)的一维数组, 注意它并不是严格意义上的1*n二维数组(shape=(1,n)), 只是在matmul中按上述规则临时提升维度参与运算。

  err.shape= (20,) X.shape= (20, 1)  np.matmul(error, X).shape= (1,)

2 向量+标量b时, numpy按广播(broadcasting)规则处理: 把标量b加到向量的每一个元素上。

y_pred = np.matmul(X, W) + b

X.shape= (20, 1) 

W.shape= (1, ) 

np.matmul(X, W)= (20,)

Loading data... Performng gradient descent (default params)... typex= <class 'numpy.ndarray'> typeW= <class 'numpy.ndarray'> type(y)= <class 'numpy.ndarray'> typeB= <class 'int'> X.shape[0]= 100 X= [[-7.24070e-01] [-2.40724e+00] [ 2.64837e+00] [ 3.60920e-01] [ 6.73120e-01] [-4.54600e-01] [ 2.20168e+00] [ 1.15605e+00] [ 5.06940e-01] [-8.59520e-01] [-5.99700e-01] [ 1.46804e+00] [-1.05659e+00] [ 1.29177e+00] [-7.45650e-01] [ 1.50330e-01] [-1.49627e+00] [-7.20710e-01] [ 3.29240e-01] [-2.80530e-01] [-1.36115e+00] [ 7.46780e-01] [ 1.06210e-01] [ 3.25600e-02] [-9.82900e-01] [-1.15661e+00] [ 9.02400e-02] [-1.03816e+00] [-6.04000e-03] [ 1.62780e-01] [-6.98690e-01] [ 1.03857e+00] [-1.17830e-01] [-9.54090e-01] [-8.18390e-01] [-1.28802e+00] [ 6.28220e-01] [-2.29674e+00] [-8.56010e-01] [-1.75223e+00] [-1.19662e+00] [ 9.77810e-01] [-1.17110e+00] [ 1.58350e-01] [-5.89180e-01] [-1.79678e+00] [-9.57270e-01] [ 6.45560e-01] [ 2.46250e-01] [ 4.59170e-01] [ 1.21036e+00] [-6.01160e-01] [ 2.68510e-01] [ 4.95940e-01] [-2.67877e+00] [ 4.94020e-01] [ 1.18643e+00] [-1.77410e-01] [ 5.79380e-01] [-2.14926e+00] [ 2.27700e+00] [-1.05695e+00] [ 1.68288e+00] [-1.53513e+00] [ 9.90000e-04] [ 4.55200e-01] [-3.78550e-01] [ 1.35638e+00] [ 1.76300e-02] [ 2.21725e+00] [-4.44420e-01] [ 8.95830e-01] [ 1.30499e+00] [ 1.08830e-01] [ 1.79466e+00] [-7.33000e-03] [ 7.98620e-01] [-1.23530e-01] [-1.34999e+00] [-6.78250e-01] [-1.79010e-01] [ 1.25770e-01] [ 1.11943e+00] [-3.02296e+00] [ 6.49650e-01] [ 1.05994e+00] [ 5.33600e-01] [-7.35910e-01] [-9.56900e-02] [ 1.04694e+00] [ 4.65110e-01] [-7.54630e-01] [-9.41590e-01] [-9.31400e-02] [-9.86410e-01] [-9.21590e-01] [ 7.69530e-01] [ 3.28300e-02] [-1.07619e+00] [ 2.01740e-01]] w= [0.] type(batch) <class 'numpy.ndarray'> np.matmul(X, W) shi ge sha: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 
np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.29708128 -0.18945281 -0.26524356 -0.16798167 -0.01454159 -0.00090461 -0.02189445 0.0921613 -0.01180925 0.0303901 0.05739996 0.0715022 0.12067307 -0.09313008 -0.26524356 -0.0837039 -0.09202184 -0.13043986 -0.13043986 -0.1181382 ] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.0815292 0.09336268 -0.20663037 -0.11841423 0.20428543 0.09336268 0.16982942 -0.10499406 0.10066208 -0.06602943 -0.05559289 0.02784003 -0.00738117 -0.16578594 0.09964234 0.03910347 0.08075707 0.02071186 -0.18568556 -0.23317991] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.10558319 -0.30071126 -0.06218052 0.02215536 0.16174741 -0.13393534 -0.1145041 -0.01728356 -0.02504598 -0.18021185 0.06424425 0.05049771 -0.33680623 -0.06360484 0.06938888 0.05049771 0.06938888 -0.13174149 0.23545823 -0.10083732] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.14675307 0.02219645 -0.13373798 0.01525506 -0.14815621 0.02827857 0.18107171 -0.13826838 0.0904903 -0.10452025 0.12557148 -0.01731561 0.07479649 -0.14815621 -0.24561593 0.09106361 -0.09507257 0.0176296 0.03763794 0.00456404] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-1.79445215e-01 -1.48460697e-02 -3.33452437e-01 -1.85652669e-01 -9.37091240e-04 -1.15685775e-01 -9.14098372e-02 1.53596081e-04 -1.48460697e-02 -6.89506770e-02 7.86505022e-02 -1.14174638e-01 -1.12337691e-01 1.79358332e-01 -1.14174638e-01 2.61094720e-01 -9.30419895e-02 -2.77729641e-02 1.79358332e-01 -1.26971209e-01] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 
1) np.matmul(X, W) shi ge sha: [-0.1230817 -0.01772103 0.10123383 -0.17996557 -0.15613399 0.33112149 0.19626387 0.13472828 0.22078576 -0.14349029 -0.01772103 0.09708895 0.01597344 -0.15895991 -0.11067713 0.0944811 -0.16185351 0.02381504 0.15940959 0.06995018] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-1.62509877e-01 1.96586931e-04 -2.13701908e-01 -1.90087648e-01 -2.45296804e-02 5.25893869e-01 -1.43780504e-01 1.33663227e-01 2.07893658e-01 2.56510202e-01 1.79191966e-02 1.00664423e-01 -2.68071102e-01 -1.48065702e-01 2.07893658e-01 1.94166330e-01 1.58584096e-01 2.40344402e-01 3.34173954e-01 -1.83002575e-01] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.15895533 -0.24368162 0.09543243 -0.02462755 0.02808345 -0.19716056 -0.30965565 -0.00159706 -0.34057097 -0.27937756 0.27461281 -0.11751103 0.15319639 0.28026335 -0.2745044 0.12036141 0.258547 -0.15578765 -0.27947275 0.31370912] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.12939871 -0.29550498 -0.29550498 0.04507322 -0.86046441 0.04633419 -0.30085342 0.51083741 -0.21480015 -0.02723749 -0.30085342 0.21904133 -0.20610146 0.02568618 -0.30085342 0.22732159 0.75383998 0.63112469 0.25499174 0.02568618] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.15803179 -0.32949007 0.19583921 0.23989072 0.3773136 0.0493635 0.01015015 -0.42084139 0.18061399 0.32376035 -0.2872934 -0.22941013 -0.18740362 -0.32949007 -0.0553052 -0.18740362 0.4576419 -0.26684971 -0.75042498 0.82559404] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.16429693 -0.05953361 0.19268523 -0.80057921 -0.80057921 -0.25096837 0.16493547 
0.17746011 -0.30649449 0.25592368 -0.24080498 0.10949581 -0.80057921 0.08929875 0.05266268 0.34818232 0.34539869 0.73739397 0.34539869 0.45109321] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-1.01041465 0.6769341 0.18634114 -0.56438333 -0.86631542 0.01228142 0.00664992 0.01238326 0.8363323 0.39980249 -0.22620295 0.0567035 0.29026172 0.49223375 0.21853837 0.8363323 -0.2558315 -0.35516163 0.99894797 -0.10581409] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.22397949 -0.49792892 0.12727917 0.40473107 -0.14634167 0.12727917 -0.10844863 -0.35627266 -0.17180601 -0.2323993 -0.04555128 -0.67738544 -0.33092043 0.04862077 -0.27991444 -0.23183489 -0.29172847 0.45865578 -0.0369923 0.17750813] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.03997731 -0.26298605 -0.03505778 0.36804648 -0.40507659 0.82870964 0.05658403 0.39406693 -0.80897881 -0.90608217 -0.27253906 0.24298799 0.55256845 0.00663591 0.23646123 0.39091647 0.19081159 -0.40507659 0.12392553 0.24298799] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.27434967 1.07125609 -0.11347337 0.0131704 0.20060594 -0.04766181 0.0131704 0.72593348 0.36236 -0.52099943 -0.92902303 -1.22277639 -0.43531496 -0.03870626 -0.43531496 -0.3975795 -0.38086975 -0.29152459 0.05087351 0.18412675] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.19103629 -0.31524416 -0.26126218 0.87710133 -0.46093851 0.07771033 -0.04538816 0.52247813 -0.04538816 0.49759033 -0.52431554 0.24866998 -0.10806027 0.44531093 0.84808803 -0.14581761 -0.57636381 -0.04538816 0.24866998 -0.00282352] np.matmul(X, W)= (20,) np.matmul(error, 
X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.24169333 0.44216306 0.31152568 0.06790507 -0.95810478 0.50491205 0.20688562 0.48225617 0.03764439 0.03764439 0.01369532 -0.56781539 0.06790507 0.32101604 -0.43307734 -0.64039264 0.32101604 0.00735451 0.94987007 0.26930089] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.13640584 0.14953103 -0.33906323 -0.63601233 0.10202265 -0.61991243 0.06560523 -0.31264713 0.0674406 -0.39010563 -0.74441528 -0.18834314 -0.44587111 0.26027481 0.01348978 -0.39010563 -0.74441528 0.11124509 -0.40867478 -0.44587111] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.85631074 0.30659706 0.67049506 0.25883433 0.4137883 0.21259755 -0.04694597 -0.15082234 -0.28714614 -0.39160819 -0.39300665 0.19759301 -0.0706839 -0.15082234 -0.71587523 0.01297259 0.4137883 -0.38139666 0.42230256 0.04231632] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.17931287 0.03554744 -0.94826256 0.41241172 -0.58941228 -0.04641572 0.19969435 -0.46132096 0.25429969 -0.33720037 -0.37091214 0.38517996 -0.45561305 -0.42393392 0.18087674 -0.38718502 0.01282607 -0.41621306 -0.8466388 0.86728814] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 4.99166150e-01 -5.41908065e-01 -2.47885432e-01 -1.18027258e-01 4.16522247e-04 4.86384387e-01 6.32482720e-02 5.43485800e-01 -3.08394754e-03 2.24501284e-01 -3.13716983e-01 -3.17495134e-01 1.38125509e-02 1.91516087e-01 4.36957080e-01 2.24501284e-01 7.08037332e-01 4.45948071e-01 5.09234208e-01 -4.13535067e-01] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.07580145 
0.71903918 -0.45159991 0.31907449 -0.07580145 -0.07580145 0.45287744 0.10521451 -0.25685586 0.417786 0.94735788 -0.31859168 0.21189882 -0.91830799 0.32879482 0.49394208 0.11472548 -0.91830799 0.287602 0.57953649] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.2075635 -0.42078408 -0.04252632 -0.43707503 -0.49137211 -1.22308595 -0.80004177 -0.27448058 0.06863841 -0.48242304 0.07432289 0.15032601 -0.6831743 0.7683776 0.15032601 -0.34455267 -0.08100273 -1.22308595 0.478017 0.478017 ] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.4766563 0.04789788 0.22861644 0.66204694 -0.32653629 -0.08000718 -0.32653629 -0.43026918 -0.41561254 0.29113037 -0.04315364 -0.32653629 0.06779483 1.02686635 -0.69230274 0.75893406 0.00795066 0.33677789 -0.42463201 -0.17071597] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) Plotting the results... Regression lines start from the lightest line, with the darkest, black line as the last line. Do you see it getting closer to the data over each iteration?

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值