您的位置:首页 > 编程语言 > Python开发

Python Intro - Numpy array shape

2015-04-15 12:27 393 查看
#!/usr/local/bin/python3

try:

import numpy as np

except ImportError:

print("numpy is not installed")

sizes = [37, 5, 6];

#biases = np.array([ 1.0, 3.0, 5.0, 6.0, 7.0]);

#weights = np.array([ 11.0, 33.0, 45.0, 26.0, 17.0, 88.0]);

biases = [np.random.randn(y, 1) for y in sizes[1:]];

weights = [np.random.randn(y, x) for x, y in zip(sizes[:-1], sizes[1:])]

nabla_b = [np.zeros(b.shape) for b in biases]

nabla_w = [np.zeros(w.shape) for w in weights]

#print("sizes = ", sizes);

#print("biases = ", biases);

#print("weights = ", weights);

#print("nabla_b = ", nabla_b);

#print("nabla_w = ", nabla_w);

print("\n\n\n");

print("===============================================================");

print("\n\n\n");

for b, w in zip(biases, weights):

print(b.shape);

t1=b.shape;

print("b = ", b);

print("\n");

print(w.shape);

print("w = ", w);

print("\n");

#a=[np.random.randn(w.shape[1], 1)];

a=[np.random.randn(w.shape[1], 1)][0];

print("a = ", a);

print("\n");

t = np.dot(w, a) + b;

print("result = ", t);

print("-------------------------------------------------------");

print("\n");

print("\n\n\n");

print("===============================================================");

a=[np.random.randn(w.shape[1], 1)]; 将导致错误的结果,

而 a=[np.random.randn(w.shape[1], 1)][0];将产生正确的结果。

另外 shape 是 Numpy array 的一个属性。。。

错误的结果如下:

===============================================================

(5, 1)

b = [[-0.80234381]

[-0.81584549]

[ 0.19263824]

[ 0.23136913]

[-0.35369136]]

(5, 37)

w = [[-1.22492256 -0.05153998 -0.2416596 0.50049006 -1.49660207 0.85092651

0.6544673 -0.3754299 -1.80912589 0.37908488 0.42568476 1.49655684

0.47841895 -1.22748695 -1.44565474 1.59204204 0.54425796 -0.99120238

-0.56170821 0.70737275 -1.00461932 1.3174263 0.11169217 0.06186711

-0.85344359 1.16538531 0.41580291 0.01258008 0.20344186 0.41829162

0.20227069 0.97059174 -0.05373197 -2.19291387 -0.4340676 1.85237748

0.27900507]

[-1.45168005 -1.14344156 -0.28568769 0.25735952 -1.26321918 -0.27439448

0.70492131 -0.90084135 -0.34983977 -0.37380968 -0.83555351 -0.8887116

0.13211296 -1.96430847 0.79747858 0.50058352 -0.37602784 -0.50755738

1.14404825 -0.73626532 0.98610876 0.21031306 -1.51556671 -1.34774195

-0.22559543 -1.4108586 0.62301475 2.17360229 -0.62072268 1.36621206

-0.33187264 -1.72646447 0.17478133 -0.19081233 -0.52705278 0.47450632

-1.07083316]

[-0.11320502 0.73540301 0.21444072 -0.12534763 1.66133067 -0.88573318

0.19579348 -1.14278562 1.06703606 -0.87959109 0.05832806 -2.84410938

0.7067185 0.15948225 1.50135313 2.05782267 1.0439752 -0.05334146

-2.39704177 -0.03350789 -1.06778715 0.10799586 -0.49407107 -0.06985971

-1.59455816 0.4528298 -0.42533419 0.46746003 -0.20312981 -0.09228466

-0.43995046 0.51760939 0.24436896 -0.39869921 -0.86469663 -2.68340844

0.2734433 ]

[ 0.42368235 0.82065636 0.33775918 -0.6113859 -1.66232777 0.10907928

0.61912058 0.47998729 -1.11836201 -0.40173418 -0.30956985 -0.52852956

-0.98234968 -1.19866915 -0.19059407 -0.52731314 0.51324251 -0.12440913

0.05918655 -0.77988089 -0.4261943 0.44117496 -0.269059 -1.19496792

-0.4348697 -1.11977948 -0.24246792 2.04652802 -0.69826327 -1.84210738

-1.97253499 -0.74212964 0.38386701 0.85001423 0.80454352 0.76351279

0.40977939]

[-0.25813088 1.81696984 -0.23022431 -0.30215019 1.05691976 0.82763261

0.94810037 -1.82938546 -0.6754055 -0.20740134 -0.19331754 1.14308525

-1.64443817 -0.92892617 0.09999019 0.31131367 -1.38775837 0.12595169

0.02417245 1.72788269 0.47844719 0.35279001 -0.23732005 0.66667847

0.27974457 0.7541109 1.49282702 -1.59411123 0.24738603 -0.33174431

-1.11101875 -0.77369412 0.21519189 0.70745882 -0.93253936 -0.33803257

-0.18741942]]

a = [array([[ 0.43080953],

[ 0.6271243 ],

[ 0.25769139],

[ 1.15770288],

[ 0.88505929],

[-1.18651882],

[-2.0328503 ],

[-0.33605496],

[-0.7184775 ],

[ 0.2481995 ],

[ 0.791939 ],

[ 0.82346004],

[-0.95404354],

[-0.80400157],

[-0.85874463],

[-0.19380442],

[-1.4377476 ],

[ 0.74277695],

[ 2.1718078 ],

[-1.42753621],

[-2.1769708 ],

[ 0.6066664 ],

[ 0.89533354],

[ 0.40967936],

[ 1.28491686],

[-0.55703151],

[-0.25216725],

[-0.46927728],

[-1.11216852],

[ 0.21308665],

[-1.23149161],

[ 0.98024705],

[ 0.25706576],

[-0.61663706],

[-1.2014525 ],

[ 0.80797824],

[ 0.41965466]])]

result = [[[ 1.58854826]

[ 1.57504658]

[ 2.58353031]

[ 2.6222612 ]

[ 2.03720071]]

[[ -4.88910208]

[ -4.90260375]

[ -3.89412003]

[ -3.85538914]

[ -4.44044963]]

[[-10.38698865]

[-10.40049033]

[ -9.3920066 ]

[ -9.35327571]

[ -9.9383362 ]]

[[ -0.03147276]

[ -0.04497444]

[ 0.96350929]

[ 1.00224018]

[ 0.41717969]]

[[ 1.71963699]

[ 1.70613532]

[ 2.71461904]

[ 2.75334993]

[ 2.16828944]]]

-------------------------------------------------------

(6, 1)

b = [[-0.01942626]

[-0.31193881]

[ 0.43504691]

[ 1.84491166]

[-0.80456819]

[ 0.03021581]]

(6, 5)

w = [[ -1.13538691e+00 -1.17284504e+00 -6.81504250e-01 2.59648757e-02

3.42289781e-01]

[ 8.39006451e-01 3.36292700e-01 1.58288576e+00 -1.25206326e-01

1.37358527e-01]

[ -4.62542151e-01 -9.32414390e-01 1.05534508e-03 1.99062893e+00

5.84544636e-01]

[ 4.63710049e-01 2.59477781e-01 -6.14149339e-02 5.65044304e-01

-6.04752454e-01]

[ 3.73561048e-01 -1.14715510e-01 -3.71172780e-01 -6.14595047e-01

-8.42810547e-01]

[ 6.79003560e-01 2.42944760e-01 1.59213349e+00 -2.18025362e-01

1.55950700e+00]]

a = [array([[-0.40012735],

[-0.42808614],

[ 1.13635464],

[ 0.78782013],

[ 1.44406382]])]

result = [[[ 0.67726523]

[ 0.38475268]

[ 1.1317384 ]

[ 2.54160315]

[-0.1078767 ]

[ 0.7269073 ]]

[[ 1.39933608]

[ 1.10682352]

[ 1.85380925]

[ 3.26367399]

[ 0.61419414]

[ 1.44897814]]

[[ 2.97837974]

[ 2.68586718]

[ 3.43285291]

[ 4.84271765]

[ 2.1932378 ]

[ 3.0280218 ]]

[[-0.81398518]

[-1.10649774]

[-0.35951201]

[ 1.05035273]

[-1.59912712]

[-0.76434312]]

[[-2.24283661]

[-2.53534916]

[-1.78836344]

[-0.37849869]

[-3.02797854]

[-2.19319454]]

[[ 3.49437571]

[ 3.20186316]

[ 3.94884888]

[ 5.35871363]

[ 2.70923378]

[ 3.54401778]]]

-------------------------------------------------------

===============================================================

正确的结果如下:

===============================================================

(5, 1)

b = [[-0.262492 ]

[-0.20629397]

[-0.01950833]

[ 0.74556297]

[-0.59764296]]

(5, 37)

w = [[-0.9970828 0.73065033 -0.4922562 2.42612649 -0.34932448 0.47006178

1.54562521 0.60566196 -0.02559831 -0.87583074 0.59756882 -0.7147512

-0.27414287 0.8355061 -0.36815386 -0.02422284 -0.12012678 1.31111424

-0.29971747 -0.53969772 1.44673392 1.78797172 -0.15529384 -0.27336318

-1.77369675 -0.85413537 1.35460962 -0.14779111 -0.78557714 -0.02437009

-0.7033722 -0.54030338 -0.72616893 -0.16379103 -0.29007255 -0.25834042

1.36683428]

[-0.56872961 0.60951407 -0.35310149 0.52013285 0.61663699 -0.30961368

-0.20963367 0.71705772 -0.5316324 0.25423788 -1.68829098 0.10111508

-0.70798256 0.50275926 -0.49245537 -0.95913379 0.5142013 0.58778145

-0.94998304 -0.76775667 -0.34393914 -0.38942604 -1.4718932 -1.03127752

0.14620239 -1.84580584 -1.63070239 -0.77504503 0.73485183 0.12520007

-0.938779 1.83107633 -0.93583973 -0.62045299 1.23843946 0.14997827

1.56348389]

[ 0.33972002 -1.18879186 0.49104202 -0.07079901 1.08109047 -1.3101739

-0.86728255 -0.64579715 1.10382172 -1.58168163 -0.24854551 -0.90613182

-1.65663682 0.58143426 0.27627428 0.16452686 0.8565011 0.42513523

-2.21190155 0.61925565 1.11953894 1.02110787 -0.17919635 1.95764644

0.58239631 1.63533897 -0.4562275 -0.67159646 0.41412204 0.11597722

0.52239521 -1.55371469 -0.02140371 0.47340526 -0.60128875 -0.98142806

-1.18309974]

[ 0.60442293 -0.11307345 1.60810941 -1.90193352 0.36896755 -1.055618

1.33453487 -0.0839704 0.21068519 1.5959486 0.4395267 0.8754726

-0.58788291 0.88672996 1.55037176 -1.5332406 0.57252895 1.54376175

-0.20691328 0.77768875 -0.8624152 0.4351931 1.23382085 0.34031488

-0.27569002 0.48206992 0.2324622 -0.18764399 0.18933579 0.43943348

-0.0666472 -0.0074146 0.44300101 -0.58442105 1.58259841 -1.45337009

-0.12645283]

[-1.17699839 1.42546676 -0.06306605 -1.35392344 -0.52848191 1.12807458

1.12138442 0.95538687 -0.32399031 -1.02436395 -0.77799109 0.5657171

-0.52596777 -0.74161448 -0.15159113 0.64114591 1.31848493 1.50740195

-0.36352739 2.26293174 0.00654035 -1.46270206 -0.31493683 0.35115431

-0.64182191 0.70758048 0.17557849 -1.0962026 -0.55330834 -0.1487515

0.85021598 -0.98307117 0.97868651 -0.51840716 -1.23989995 -0.20878353

-0.29166749]]

a = [[ 0.57490808]

[ 0.93551352]

[ 0.48746221]

[-1.05142943]

[ 0.17675075]

[ 0.30799841]

[-1.85511516]

[-0.1011902 ]

[ 0.17074073]

[ 0.07103635]

[-0.3394694 ]

[-1.21016465]

[ 2.27487407]

[ 0.1694055 ]

[ 0.2834597 ]

[-0.48268838]

[ 1.69990823]

[ 0.18432417]

[ 0.41151327]

[ 0.08262408]

[ 0.11720957]

[-1.70657664]

[ 0.47772809]

[-0.47040207]

[-0.50707268]

[-0.50235741]

[ 0.1109001 ]

[ 0.0418734 ]

[-0.91857133]

[-1.07522034]

[ 0.02980418]

[-0.83182824]

[-0.98320158]

[ 1.87120068]

[ 0.26556049]

[-1.26170258]

[-0.45455707]]

result = [[-6.04175764]

[-2.77462728]

[-1.38691038]

[ 0.85500108]

[ 1.63308819]]

-------------------------------------------------------

(6, 1)

b = [[ 0.81563918]

[-0.56125996]

[ 1.64255057]

[-1.78762905]

[-0.10637609]

[-1.30160507]]

(6, 5)

w = [[-0.41877003 0.12253528 0.71806184 -0.59645841 0.34434439]

[ 1.84111609 -1.30775289 -1.24070082 0.55306451 -1.94332213]

[ 0.40351707 0.17099953 0.60908037 -0.25724238 -0.32904466]

[-0.15303428 -0.01397066 -0.55895095 -1.47137884 -0.52661826]

[ 0.2487372 -0.88314338 -0.15866652 0.10156308 1.42306267]

[-0.76701692 1.19854978 0.31583967 1.04686157 0.26428875]]

a = [[-0.62313353]

[ 0.83691557]

[ 0.16848985]

[ 1.11337741]

[ 2.09855381]]

result = [[ 1.35866856]

[-6.47444192]

[ 0.65991579]

[-4.54147481]

[ 2.0722289 ]

[ 1.95282584]]

-------------------------------------------------------

===============================================================


内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: