This commit is contained in:
Leon Wilzer 2022-12-27 18:31:12 +01:00
parent 32b1e69f55
commit e65a4df5f7

View File

@ -1,40 +1,55 @@
#include <exception>
#include<iostream>
#include <iostream>
#include <ostream>
#include <system_error>
#include<vector>
#include<initializer_list>
#include<stdexcept>
#include <vector>
#include <initializer_list>
#include <stdexcept>
#include <type_traits>
#include <chrono>
template<typename T>
typename std::enable_if<std::is_arithmetic<T>::value, &operator*>
class matrix
{
public:
matrix(size_t rows, size_t columns) : rows(rows), columns(columns)
{
data();
data = std::vector<T>(rows*columns);
}
matrix(size_t rows, size_t columns, const T &ival) : rows(rows), columns(columns)
{
data = std::vector<T>(rows*columns, ival);
}
matrix(std::initializer_list<std::initializer_list<T>> imat)
{
size_t column_num = 0;
size_t row_num = 0;
size_t prev_length;
bool isNotFirst = false;
for(std::initializer_list<T> row : imat)
{
if(prev_length!=row.size())
if(isNotFirst&&prev_length!=row.size())
{
throw std::domain_error("provided initializer list with empty spots.");
throw std::domain_error("provided initializer list with empty elements.");
}
++row_num;
prev_length = row.size();
for(T t : row)
{
if(!isNotFirst) { ++column_num; }
data.push_back(t);
}
isNotFirst = true;
}
this->rows = row_num;
this->columns = column_num;
}
T& operator() (size_t row, size_t column)
{
@ -67,8 +82,51 @@ class matrix
size_t num_rows() const noexcept { return rows; }
size_t num_columns() const noexcept { return columns; }
std::vector<T> column_to_vec(size_t column) const
{
std::vector<T> out(this->num_rows());
for(size_t i=0; i<this->num_rows();++i)
{
out.push_back((*this)(i,column));
}
return out;
}
std::vector<T> row_to_vec(size_t row) const
{
std::vector<T> out(this->num_columns());
for(size_t i=0; i<this->num_columns();++i)
{
out.push_back((*this)(row,i));
}
return out;
}
static T dot_product(std::vector<T> lhs, std::vector<T> rhs)
{
if(lhs.size()!=rhs.size())
{
throw std::domain_error("vectors must be the same size");
}
T sum = 0;
for(size_t i=0; i<lhs.size(); ++i)
{
sum += lhs.at(i)*rhs.at(i);
}
return sum;
}
friend matrix operator* (const matrix &lhs, const T &scale)
{
if(!std::is_arithmetic<T>())
{
throw std::domain_error("type must be arithmetic.");
}
matrix out(lhs.num_rows(), lhs.num_columns(), 0);
for(size_t x=0; x<lhs.num_columns(); ++x)
{
@ -80,10 +138,25 @@ class matrix
return out;
}
typename std::enable_if<std::is_arithmetic<T>::value>
friend matrix operator* (const matrix &lhs, const matrix &rhs) = delete;
friend matrix operator* (const matrix &lhs, const matrix &rhs)
{
//TODO
throw std::logic_error("unimplemented");
if(!std::is_arithmetic<T>())
{
throw std::domain_error("type must be arithmetic.");
}
if(lhs.num_columns()!=rhs.num_rws(rows), columns(columns)
{
for(size_t y=0; y<out.num_rows(); ++y)
{
out(y,x)=dot_product(lhs.row_to_vec(y), rhs.column_to_vec(x));
}
}
return out;
}
friend bool operator== (const matrix &lhs, const matrix &rhs)
@ -112,10 +185,10 @@ class matrix
}
friend std::ostream& operator<< (std::ostream &os, const matrix &m)
{
for(size_t x=0; x<m.num_columns(); ++x)
{
for(size_t y=0; y<m.num_rows(); ++y)
{
for(size_t x=0; x<m.num_columns(); ++x)
{
os << m(y,x) << '\t';
}
@ -132,8 +205,36 @@ class matrix
int main()
{
matrix<int> m(5,5,5);
std::cout << m;
std::cout << m*5;
std::cout << m*m;
// // TODO comment-in the following code as needed to test your implementation
// // above.
matrix<double> a(3, 3, 3);
a = a * 2;
matrix<double> b(3, 3, 4);
auto start = std::chrono::steady_clock::now();
matrix<double> c = a * b;
auto end = std::chrono::steady_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
std::cout << "multiplied a:\n";
std::cout << a << '\n';
std::cout << "with b:\n";
std::cout << b << '\n';
std::cout << "in " << duration << "ns\n";
std::cout << "result is:\n";
std::cout << c << '\n';
matrix<double> d = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
matrix<double> e = {{9, 8, 7}, {6, 5, 4}, {3, 2, 1}};
start = std::chrono::steady_clock::now();
matrix<double> f = d * e;
end = std::chrono::steady_clock::now();
duration =
std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
std::cout << "multiplied d:\n";
std::cout << d << '\n';
std::cout << "with e:\n";
std::cout << e << '\n';
std::cout << "in " << duration << "ns\n";
std::cout << "result is:\n";
std::cout << f << '\n';
}