Current File : //usr/lib64/python2.7/site-packages/numpy/core/tests/test_blasdot.py
import numpy as np
import sys
from numpy.core import zeros, float64
from numpy.testing import dec, TestCase, assert_almost_equal, assert_, \
     assert_raises, assert_array_equal, assert_allclose, assert_equal
from numpy.core.multiarray import inner as inner_

DECPREC = 14

class TestInner(TestCase):
    def test_vecself(self):
        """Ticket 844."""
        # Inner product of a vector with itself segfaults or give meaningless
        # result
        a = zeros(shape = (1, 80), dtype = float64)
        p = inner_(a, a)
        assert_almost_equal(p, 0, decimal = DECPREC)

try:
    import numpy.core._dotblas as _dotblas
except ImportError:
    _dotblas = None

@dec.skipif(_dotblas is None, "Numpy is not compiled with _dotblas")
def test_blasdot_used():
    from numpy.core import dot, vdot, inner, alterdot, restoredot
    assert_(dot is _dotblas.dot)
    assert_(vdot is _dotblas.vdot)
    assert_(inner is _dotblas.inner)
    assert_(alterdot is _dotblas.alterdot)
    assert_(restoredot is _dotblas.restoredot)


def test_dot_2args():
    from numpy.core import dot

    a = np.array([[1, 2], [3, 4]], dtype=float)
    b = np.array([[1, 0], [1, 1]], dtype=float)
    c = np.array([[3, 2], [7, 4]], dtype=float)

    d = dot(a, b)
    assert_allclose(c, d)

def test_dot_3args():
    np.random.seed(22)
    f = np.random.random_sample((1024, 16))
    v = np.random.random_sample((16, 32))

    r = np.empty((1024, 32))
    for i in xrange(12):
        np.dot(f,v,r)
    assert_equal(sys.getrefcount(r), 2)
    r2 = np.dot(f,v,out=None)
    assert_array_equal(r2, r)
    assert_(r is np.dot(f,v,out=r))

    v = v[:,0].copy() # v.shape == (16,)
    r = r[:,0].copy() # r.shape == (1024,)
    r2 = np.dot(f,v)
    assert_(r is np.dot(f,v,r))
    assert_array_equal(r2, r)

def test_dot_3args_errors():
    np.random.seed(22)
    f = np.random.random_sample((1024, 16))
    v = np.random.random_sample((16, 32))

    r = np.empty((1024, 31))
    assert_raises(ValueError, np.dot, f, v, r)

    r = np.empty((1024,))
    assert_raises(ValueError, np.dot, f, v, r)

    r = np.empty((32,))
    assert_raises(ValueError, np.dot, f, v, r)

    r = np.empty((32, 1024))
    assert_raises(ValueError, np.dot, f, v, r)
    assert_raises(ValueError, np.dot, f, v, r.T)

    r = np.empty((1024, 64))
    assert_raises(ValueError, np.dot, f, v, r[:,::2])
    assert_raises(ValueError, np.dot, f, v, r[:,:32])

    r = np.empty((1024, 32), dtype=np.float32)
    assert_raises(ValueError, np.dot, f, v, r)

    r = np.empty((1024, 32), dtype=int)
    assert_raises(ValueError, np.dot, f, v, r)