//
// FILE: checksparsematrix.cc
//

// If this program produces any other warning or error messages when
// compiled, it means you need to fix a problem in your matrix class.

using namespace std;

#include "SparseMatrix.hh"

#include <climits>
#include <iostream>
#include <set>
#include <sstream>
#include <stdlib.h>
#include <string>

// Reliably reproducible pseudorandom numbers: both helpers draw from the
// single generator state seeded by srand48(), so the same seed always
// produces the same sequence of tests.

// Raw pseudorandom value, exactly as produced by lrand48().
int rnd()
{
    const long raw = lrand48();
    return static_cast<int>(raw);
}

// Pseudorandom value in [0, mac) for positive mac (always 0 when mac is 0).
int rnd(int mac)
{
    const double fraction = drand48();
    return static_cast<int>(fraction * mac);
}

// ----------------------------------------------------------------------

// Normally, this declaration would go in a separate header file.

class ErrorContext                          // displays test results
{
public:
    // Writes the column headings to os; all later output also goes to os.
    // explicit: an ostream should never silently turn into an ErrorContext.
    explicit ErrorContext(ostream &os);
    void desc(const char *msg, int line);   // write line/description
    void desc(string msg, int line);
    void result(bool good);                 // write test result
    ~ErrorContext();                        // write summary info
    bool ok() const;                        // true iff all tests passed
    
private:
    // Not copyable: a copy would keep its own pass/fail tallies and print a
    // second summary from its destructor.  Declared but intentionally never
    // defined (the C++98 way to forbid copying).
    ErrorContext(const ErrorContext &);
    ErrorContext &operator=(const ErrorContext &);

    ostream &os;                            // output stream to use
    int passed;                             // # of tests which passed
    int total;                              // total # of tests
    int lastline;                           // line # of most recent test
    set<int> badlines;                      // line #'s of failed tests
    bool skip;                              // skip a line before title?
};

// ----------------------------------------------------------------------

// Normally, these method implementations would go in a separate source file.

ErrorContext::ErrorContext(ostream &os)
    : os(os), passed(0), total(0), lastline(0), skip(false)
{
    // Heading row: "line: ", then "description" left-aligned in a 65-column
    // field, then " result".  desc() and result() use the same layout, so
    // each verdict falls under the "result" heading.
    os << "line: ";
    os.setf(ios::left, ios::adjustfield);
    os.width(65);
    os << "description" << " result" << endl;

    // Underline the headings with 78 tildes (the full width of that row).
    os.fill('~');
    os.width(78);
    os << "~" << endl;

    // Put back a blank fill and right alignment for the output that follows.
    os.fill(' ');
    os.setf(ios::right, ios::adjustfield);
}


void 
ErrorContext::desc(const char *msg, int line)
{
    // End a previous description that never received a result; otherwise,
    // once anything has been described, leave a blank line before each
    // "---" section heading.
    const bool pending = (lastline != 0);

    if (pending || ((msg[0] == '-') && skip))
    {
        os << endl;
    }

    // The line number in a four-column field, then ": " ...
    os.width(4);
    os << line << ": ";

    // ... then the description, left-aligned and padded to 65 columns so
    // that the verdict written by result() lines up with the heading.
    os.setf(ios::left, ios::adjustfield);
    os.width(65);
    os << msg << " ";
    os.setf(ios::right, ios::adjustfield);

    // Push the text out now, before the test itself runs.
    os.flush();

    skip = true;
    lastline = line;
}


void 
ErrorContext::desc(string msg, int line)
{
    // Produces the same output as the C-string overload.  Delegate to it so
    // the formatting logic lives in one place.
    //
    // Handing over c_str() also stops the text at the first '\0'.  Callers
    // build their descriptions with "ostringstream << ... << ends", and on an
    // ostringstream "ends" appends a literal NUL character.  Previously that
    // NUL was written to the report as a stray byte and counted against the
    // 65-column padding.
    desc(msg.c_str(), line);
}


// From here on, a one-argument call such as ec.desc("message") expands to
// ec.desc("message", __LINE__), so every test records the source line it
// sits on (result() stores that line for failed tests and the destructor
// lists them).  This must come after the declarations and definitions of
// desc() above: a one-argument function-like macro would reject their
// two-parameter lists.
#define desc(x) desc(x, __LINE__)

void 
ErrorContext::result(bool good)
{
    if (good)
    {
        os << "ok";
        passed++;
    }
    else
    {
        os << "ERROR";
        badlines.insert(lastline);
    }
    
    os << endl;
    total++;
    lastline = 0;
}


ErrorContext::~ErrorContext()
{
    // Overall score.
    os << endl << "Passed " << passed << "/" << total << " tests." << endl
       << endl;

    if (badlines.empty())
    {
        return;
    }

    // Point at every failing test by its line in this source file.
    os << "For more information, please consult:" << endl;

    set<int>::const_iterator where;
    for (where = badlines.begin(); where != badlines.end(); ++where)
    {
        os << "  " << __FILE__ << ", line " << *where << endl;
    }

    os << endl;

    // With many failures, later ones are often consequences of earlier ones.
    if (badlines.size() > 2)
    {
        os << "We recommend that you fix the topmost failure "
              "before going on."
           << endl << endl;
    }
}


bool 
ErrorContext::ok() const
{
    return (passed == total);
}


// ----------------------------------------------------------------------

void 
empty(ErrorContext &ec)
{
    ec.desc("--- Empty matrices (0 by 0) ---");

    // Default constructor.
    //
    // If this test fails, getrows returned a non-zero value for an "empty"
    // (0x0) matrix created with the default constructor.
    ec.desc("default constructor and getrows");
    const SparseMatrix viaDefault;
    ec.result(viaDefault.getrows() == 0);

    // The same check for getcols.
    ec.desc("default constructor and getcols");
    ec.result(viaDefault.getcols() == 0);

    // Two-argument constructor.
    //
    // If this test fails, getrows returned a non-zero value for an "empty"
    // (0x0) matrix created with the two-argument constructor.
    ec.desc("two-argument constructor and getrows");
    const SparseMatrix viaSize(0, 0);
    ec.result(viaSize.getrows() == 0);

    // The same check for getcols.
    ec.desc("two-argument constructor and getcols");
    ec.result(viaSize.getcols() == 0);
}


// Exercises the two-argument constructor, getrows/getcols and
// getelem/setelem on matrices of a random size between 1 by 1 and
// level*level by level*level.
//
// Descriptions are built without "<< ends".  That manipulator belongs to the
// old strstream API; on an ostringstream it appends a literal '\0' character,
// which would be printed into the report.
void 
basic(ErrorContext &ec, int level)
{
    const int rows = rnd(level * level) + 1;
    const int cols = rnd(level * level) + 1;

    {
        ostringstream oss;
        oss << "--- Basic operations (" 
            << rows << " by " 
            << cols << " matrix) ---";
        ec.desc(oss.str());
        
        // Make sure getrows and getcols work for non-empty matrices.

        ec.desc("getrows and getcols");
        SparseMatrix a(rows, cols);
        ec.result((a.getrows() == rows) && (a.getcols() == cols));
    }
    
    // Repeat "level" times.

    for (int pass = 1; pass <= level; pass++)
    {
        // Choose a valid matrix position and value.
        const int r = rnd(rows);
        const int c = rnd(cols);
        const int v = rnd();
        
        // Let the user know what we're about to try.
        ostringstream oss;
        oss << "getelem and setelem (1), pass " << pass 
            << ": element (" << r
            << "," << c << ")";
        ec.desc(oss.str());
        
        // Create a new matrix of the appropriate size.
        SparseMatrix a(rows, cols);
        
        // Set the value.
        a.setelem(r, c, v);
        
        // Verify the value.
        ec.result(a.getelem(r, c) == v);
    }
    
    
    // Repeat "level" times.
    
    for (int pass = 1; pass <= level; pass++)
    {
        {
            // If this part (2a) of the test fails, but part (1) succeeded,
            // you might have forgotten to explicitly initialize the contents
            // of your matrix to all-zeroes in your two-argument constructor.
            // You can do this using a loop which sets each element to zero.
            
            ostringstream oss;
            oss << "getelem and setelem (2a), pass " 
                << pass << " [read]";
            ec.desc(oss.str());
        }
        
        // Create a new matrix.
        SparseMatrix a(rows, cols);
        
        // Ensure that the matrix initially contains only zeroes.
        bool good = true;

        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < cols; c++)
            {
                good &= (a.getelem(r, c) == 0);
            }
        }
        
        ec.result(good);
        
        {
            // A failure here may be due to a bug in getelem, setelem, or
            // your two-argument constructor.  Did you swap the order of the
            // row and column somewhere?  Did you allocate enough space?
            
            ostringstream oss;
            oss << "getelem and setelem (2b), pass " 
                << pass << " [write]";
            ec.desc(oss.str());
        }
        
        // Fill the matrix with various exciting values.
        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < cols; c++)
            {
                a.setelem(r, c, (r - 666 * c) * pass);
            }
        }
        
        // Now read the values back out and see if they're all correct.
        good = true;

        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < cols; c++)
            {
                good &= (a.getelem(r, c) == ((r - 666 * c) * pass));
            }
        }
        
        ec.result(good);
    }
}


// Value stored at (r, c) of the large matrix used by copy().
static int
copy_pattern(int r, int c)
{
    return r * (c - 246);
}


// Checks the copy constructor and the assignment operator, including copy
// depth, resizing on assignment, and self-assignment.
void 
copy(ErrorContext &ec)
{
    // Size of the large test matrix.  Every check below scans all "cols"
    // columns.  Scanning only the first "rows" of them would miss copy or
    // assignment code that mixes up the row and column counts.
    const int rows = 123;
    const int cols = 456;

    // Size of the replacement matrix used by the later assignment tests.
    const int small_rows = 23;
    const int small_cols = 5;

    ec.desc("--- Copying matrices ---");


    ec.desc("copy constructor: target dimensions");
    
    // Create a largish matrix and fill in every element.
    SparseMatrix a(rows, cols);

    for (int r = 0; r < rows; r++)
    {
        for (int c = 0; c < cols; c++)
        {
            a.setelem(r, c, copy_pattern(r, c));
        }
    }
    
    // Invoke copy constructor.
    const SparseMatrix &ref = a;
    const SparseMatrix b = ref;
    
    // See if the resulting matrix has the correct size.
    ec.result((b.getrows() == rows) && (b.getcols() == cols));
    
    
    ec.desc("copy constructor: target values");
    
    // See if the values were copied correctly.
    bool good = true;
    
    for (int r = 0; r < rows; r++)
    {
        for (int c = 0; c < cols; c++)
        {
            good &= (b.getelem(r, c) == copy_pattern(r, c));
        }
    }
    
    ec.result(good);
    
    
    ec.desc("copy constructor: source dimensions and values");
    
    // Verify that the original matrix was not changed by the copy.
    good = (a.getrows() == rows) && (a.getcols() == cols);

    for (int r = 0; r < rows; r++)
    {
        for (int c = 0; c < cols; c++)
        {
            good &= (a.getelem(r, c) == copy_pattern(r, c));
        }
    }
    
    ec.result(good);
    
    
    // This next test should always pass, unless a previous test failed or you
    // foolishly ignored a compiler warning about "const" somewhere.
    ec.desc("copy constructor: depth of copy");
    
    // Systematically eradicate the original values.
    for (int r = 0; r < rows; r++)
    {
        for (int c = 0; c < cols; c++)
        {
            a.setelem(r, c, 23);
        }
    }
    
    // Verify that the original matrix is gone but the copy is still there.
    good = (a.getrows() == rows) && (a.getcols() == cols)
        && (b.getrows() == rows) && (b.getcols() == cols);

    for (int r = 0; r < rows; r++)
    {
        for (int c = 0; c < cols; c++)
        {
            good &= ((b.getelem(r, c) == copy_pattern(r, c))
                     && (a.getelem(r, c) == 23));
        }
    }
    
    ec.result(good);
    
    
    // If the previous test fails, this next test may appear to succeed, even
    // if the assignment operator fails to work.
    ec.desc("assignment operator: matrices of equal size");
    
    // Copy the values back using the assignment operator.
    a = b;
    
    // Make sure everything is the way it should be.
    good = (a.getrows() == rows) && (a.getcols() == cols)
        && (b.getrows() == rows) && (b.getcols() == cols);

    for (int r = 0; r < rows; r++)
    {
        for (int c = 0; c < cols; c++)
        {
            good &= (a.getelem(r, c) == copy_pattern(r, c)) 
                 && (b.getelem(r, c) == copy_pattern(r, c));
        }
    }
    
    ec.result(good);
    
    
    ec.desc("assignment operator: matrices of differing size");
    
    // Replace a with a zeroed small_rows x small_cols matrix.
    a = SparseMatrix(small_rows, small_cols);
    
    // Verify the new size and contents of a.
    good = (a.getrows() == small_rows) && (a.getcols() == small_cols);

    for (int r = 0; r < small_rows; r++)
    {
        for (int c = 0; c < small_cols; c++)
        {
            good &= (a.getelem(r, c) == 0);
        }
    }
    
    ec.result(good);
    
    
    // NOTE: This test may fail to detect a common problem!  But we'll try:
    ec.desc("assignment operator: self-assignment");
    
    // Start out by filling the new small matrix with some nice 1-heavy garbage.
    for (int r = 0; r < small_rows; r++)
    {
        for (int c = 0; c < small_cols; c++)
        {
            a.setelem(r, c, ~(r ^ (c << 16)));
        }
    }
    
    // Assign a to itself in various exciting and different ways.
    a = ((a = a = a) = (a = a = a) = (a = a = a)) = a;
    
    // See how ineffective that really was.
    good = (a.getrows() == small_rows) && (a.getcols() == small_cols);

    for (int r = 0; r < small_rows; r++)
    {
        for (int c = 0; c < small_cols; c++)
        {
            good &= (a.getelem(r, c) == ~(r ^ (c << 16)));
        }
    }
    
    ec.result(good);
}


// Checks operator== and operator!= on empty, identical, nearly identical and
// differently sized matrices, repeating the checks (level / 3)^2 times.
//
// Descriptions are built without "<< ends": on an ostringstream that
// manipulator appends a literal '\0', which would be printed into the report.
void 
comp(ErrorContext &ec, int level)
{
    ec.desc("--- Comparison operators ---");


    for (int pass = 1; pass <= (level / 3) * (level / 3); pass++)
    {
        ec.desc("two empty matrices, equality");
        ec.result(SparseMatrix() == SparseMatrix(0, 0));
        
        ec.desc("two empty matrices, inequality");
        ec.result(!(SparseMatrix(0, 0) != SparseMatrix()));
        
        
        const int rows = level + rnd(level);
        const int cols = level + rnd(level * level);

        {
            ostringstream oss;
            oss << "identical " << rows << " by " 
                << cols << " matrices, equality";
            ec.desc(oss.str());
        }
        
        // Create two identical random matrices.
        SparseMatrix a(rows, cols), b(rows, cols);

        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < cols; c++)
            {
                int v = rnd();
                a.setelem(r, c, v);
                b.setelem(r, c, v);
            }
        }
        
        ec.result((a == b) && (b == a));
        
        {
            ostringstream oss;
            oss << "identical " << rows << " by " 
                << cols << " matrices, inequality";
            ec.desc(oss.str());
        }
        
        ec.result(!((a != b) || (b != a)));
        
        
        int row = rnd(rows);
        int col = rnd(cols);

        {
            ostringstream oss;
            oss << "matrices differing only at (" << row << "," << col
                << "), equality";
            ec.desc(oss.str());
        }
        
        // Change one random element of a.  Flipping its lowest bit always
        // alters the value and, unlike adding 1, cannot overflow (signed
        // overflow is undefined behavior; rnd() may return INT_MAX).
        a.setelem(row, col, a.getelem(row, col) ^ 1);
        
        ec.result(!((a == b) || (b == a)));
        
        
        {
            ostringstream oss;
            oss << "matrices differing only at (" << row << "," << col
                << "), inequality";
            ec.desc(oss.str());
        }
        
        ec.result((a != b) && (b != a));
        
        
        // Ensure that if one dimension is zero, both will be
        if ((row == 0) || (col == 0))
        {
            row = col = 0;
        }
        
        {
            ostringstream oss;
            oss << rows << " by " << cols << " vs. " << row 
                << " by " << col
                << ", equality";
            ec.desc(oss.str());
        }
        
        // Make b into a "subset" of a.
        b = SparseMatrix(row, col);

        for (int r = 0; r < row; r++)
        {
            for (int c = 0; c < col; c++)
            {
                b.setelem(r, c, a.getelem(r, c));
            }
        }
        
        ec.result(!((a == b) || (b == a)));
        
        
        {
            ostringstream oss;
            oss << rows << " by " << cols << " vs. " << row 
                << " by " << col
                << ", inequality";
            ec.desc(oss.str());
        }
        
        ec.result((a != b) && (b != a));
    }
}


// Largest power of two "b" such that level * b * b still fits in an int.
// Every random entry is drawn from [0, b), so each reference sum and each
// dot product of at most "level" terms stays below INT_MAX.  Neither this
// test nor a correct SparseMatrix then triggers signed overflow, which is
// undefined behavior.
static int
math_bound(int level)
{
    if (level < 1)
    {
        level = 1;
    }

    int bound = 1;

    // Double while the doubled bound would still satisfy the limit.
    while (4LL * bound * bound * level <= INT_MAX)
    {
        bound *= 2;
    }

    return bound;
}


// Builds a description such as "(2 by 3) * (3 by 4), return value".
static string
math_label(int lrows, int lcols, const char *op,
           int rrows, int rcols, const char *what)
{
    ostringstream oss;
    oss << "(" << lrows << " by " << lcols << ") " << op
        << " (" << rrows << " by " << rcols << "), " << what;
    return oss.str();
}


// Checks +, -, * and +=, -=, *= against reference results computed here
// with getelem/setelem.  Each operator is tested for its return value and
// for side effects on its operands.
void
math(ErrorContext &ec, int level)
{
    ec.desc("--- Arithmetic operators ---");

    // Note that these tests depend heavily on previously tested operations!
    // They should not be run unless everything else checks out OK.

    if (!ec.ok())
    {
        ec.desc("one or more previous failures; skipping this section");
        ec.result(false);
        return;
    }

    // Exclusive upper bound for random entries (see math_bound).  Using
    // rnd(bound) instead of rnd() consumes the same number of random draws,
    // so the matrix dimensions chosen below are unaffected.
    const int bound = math_bound(level);

    for (int pass = 1; pass <= (level / 2); pass++)
    {
        // Make up some dimensions for these matrices (all zero on the first
        // pass, so empty matrices are exercised too).
        const int x = (pass > 1) ? (rnd(level) + 1) : 0;
        const int y = (pass > 1) ? (rnd(level) + 1) : 0;
        const int z = (pass > 1) ? (rnd(level) + 1) : 0;
        
        // These will be three random matrices.
        SparseMatrix a(x, y);
        SparseMatrix b(x, y);
        SparseMatrix c(y, z);
        
        // These will be the correct results of arithmetic operations.
        SparseMatrix a_plus_b(x, y);
        SparseMatrix a_minus_b(x, y);
        SparseMatrix a_times_c(x, z);
        
        // Fill in the values and results.
        
        for (int ix = 0; ix < x; ix++)
        {
            for (int iy = 0; iy < y; iy++)
            {
                const int va = rnd(bound);
                const int vb = rnd(bound);
                
                a.setelem(ix, iy, va);
                b.setelem(ix, iy, vb);
                
                a_plus_b.setelem(ix, iy, va + vb);
                a_minus_b.setelem(ix, iy, va - vb);
            }
        }
        
        for (int iy = 0; iy < y; iy++)
        {
            for (int iz = 0; iz < z; iz++)
            {
                const int vc = rnd(bound);
                c.setelem(iy, iz, vc);
                
                for (int ix = 0; ix < x; ix++)
                {
                    a_times_c.setelem(ix, iz, 
                                      a_times_c.getelem(ix, iz) 
                                      + vc * a.getelem(ix, iy));
                }
            }
        }
        
        // Set up read-only copies of the three reference matrices.
        const SparseMatrix copy_a = a;
        const SparseMatrix copy_b = b;
        const SparseMatrix copy_c = c;
        
        // Check non-destructive addition.
        ec.desc(math_label(x, y, "+", x, y, "return value"));
        ec.result(copy_a + copy_b == a_plus_b);
        
        // Ensure arguments were not altered.
        ec.desc(math_label(x, y, "+", x, y, "side effects"));
        ec.result((copy_a == a) && (copy_b == b));
        
        // Check non-destructive subtraction.
        ec.desc(math_label(x, y, "-", x, y, "return value"));
        ec.result(copy_a - copy_b == a_minus_b);
        
        // Ensure arguments were not altered.
        ec.desc(math_label(x, y, "-", x, y, "side effects"));
        ec.result((copy_a == a) && (copy_b == b));
        
        // Check non-destructive multiplication.
        ec.desc(math_label(x, y, "*", y, z, "return value"));
        ec.result(copy_a * copy_c == a_times_c);
        
        // Ensure arguments were not altered.
        ec.desc(math_label(x, y, "*", y, z, "side effects"));
        ec.result((copy_a == a) && (copy_c == c));
        
        // Check destructive addition.
        ec.desc(math_label(x, y, "+=", x, y, "return value"));
        ec.result((a_minus_b += copy_b) == copy_a);
        
        // Ensure LHS was altered and RHS was not.
        ec.desc(math_label(x, y, "+=", x, y, "side effects"));
        ec.result((a_minus_b == copy_a) && (copy_b == b));
        
        // Check destructive subtraction.
        ec.desc(math_label(x, y, "-=", x, y, "return value"));
        ec.result((a_plus_b -= b) == copy_a);
        
        // Ensure LHS was altered and RHS was not.
        ec.desc(math_label(x, y, "-=", x, y, "side effects"));
        ec.result((a_plus_b == copy_a) && (copy_b == b));
        
        // Check destructive multiplication.
        ec.desc(math_label(x, y, "*=", y, z, "return value"));
        ec.result((a *= copy_c) == a_times_c);
        
        // Ensure LHS was altered and RHS was not.
        ec.desc(math_label(x, y, "*=", y, z, "side effects"));
        ec.result((a == a_times_c) && (copy_c == c));
    }
}


// ----------------------------------------------------------------------

// Runs the checks enabled by the test level (argv[1], default 5).
// Exits with 0 if every check passed and 1 otherwise.
//
// argv is declared "char *argv[]", the form the C++ standard requires every
// implementation to accept for main.
int
main(int argc, char *argv[])
{
    // Use first argument, if any, as level.  strtol is used instead of atoi
    // because atoi has undefined behavior when the number does not fit.
    // Anything that is not a positive int falls back to the default below.
    int level = 0;

    if (argc > 1)
    {
        const long parsed = strtol(argv[1], 0, 10);

        if ((parsed > 0) && (parsed <= INT_MAX))
        {
            level = int(parsed);
        }
    }
    
    if (level <= 0)
    {
        cout << "Test level may be specified as a "
                "positive integer argument."
             << endl
             << "Larger levels yield more tests; "
                "smaller levels yield fewer tests."
             << endl << endl;
        
        level = 5;
    }
    
    // Set everything up.
    
    cout << "Performing a level " 
         << level << " check of class SparseMatrix." << endl
         << endl;
    ErrorContext ec(cout);
    srand48(level);     // same level => same pseudorandom test sequence
    
    // Perform the appropriate checks for this level.
    
    empty(ec);
    
    if (level >= 2)
    {
        basic(ec, level);
    }
    
    if (level >= 3)
    {
        copy(ec);
    }
    
    if (level >= 4)
    {
        comp(ec, level);
    }
    
    if (level >= 5)
    {
        math(ec, level);
    }
    
    // Return 0 (success) if all the checks passed, 1 otherwise.
    
    return (ec.ok() ? 0 : 1);
}

