Multivariable Linear Regression… In Java, Part 1

Arnav Kartikeya
4 min read · Nov 17, 2021

A while back I wrote a blog on linear regression in Octave. Recently I decided to brush up on my skills and really start learning about the math behind machine learning concepts like linear regression. I decided to take on the project of creating my own implementations of machine learning concepts in Java, so that I could truly say that I fully understand those concepts, and so that I could use a language I was more familiar with.

This blog is about the machine learning concept of multivariable linear regression (I am going to call this MLR for the rest of the blog). For Part 1, I will go over the implementation of matrices in Java.

Matrices: The first step

Matrices are fundamental to linear regression and machine learning in general. They make code shorter and often faster, and they can handle any number of variables instead of a hard-coded amount. So the first step toward implementing linear regression is to handle matrices.

A matrix can be defined as a rectangular array of numbers arranged in rows and columns. Matrix dimensions are written as height by width (rows by columns), so a 2 by 5 matrix has 2 rows and 5 columns. The most obvious way to implement this in Java is with a two-dimensional array. For convenience, our matrices will store double values, so in Java we can create a Matrix class whose constructor takes a two-dimensional array of doubles and records its height and width.

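public class Matrix {
    private double[][] matrix; // the values, stored row by row
    private int height;        // number of rows
    private int width;         // number of columns

    // Wraps an existing 2D array (assumes it is non-empty and rectangular)
    public Matrix(double[][] values) {
        this.height = values.length;
        this.width = values[0].length;
        this.matrix = values;
    }

    // ...the accessor methods and matrix operations below go inside this class
}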
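The constructor above assumes that `values` has at least one row and that every row has the same length, and it keeps a reference to the caller's array, so later changes to that array also change the matrix. A minimal sketch of a more defensive version follows, assuming you want to reject empty or jagged input and copy the data. The class name `SafeMatrix` is hypothetical, and the `getHeight`, `getWidth`, and `getElement` accessors are my guess at the simple helpers the later methods rely on, not the code from the original repository.

```java
import java.util.Arrays;

// Hypothetical defensive variant of the Matrix constructor (not the author's code).
class SafeMatrix {
    private final double[][] matrix;
    private final int height;
    private final int width;

    SafeMatrix(double[][] values) {
        if (values == null || values.length == 0 || values[0].length == 0) {
            throw new IllegalArgumentException("Matrix needs at least one row and one column");
        }
        this.height = values.length;
        this.width = values[0].length;
        this.matrix = new double[height][];
        for (int i = 0; i < height; i++) {
            // Every row must have the same length, otherwise it is not a matrix
            if (values[i].length != width) {
                throw new IllegalArgumentException(
                    "Row " + i + " has length " + values[i].length + ", expected " + width);
            }
            // Copy each row so later edits to 'values' do not leak into this matrix
            this.matrix[i] = Arrays.copyOf(values[i], width);
        }
    }

    // Hypothetical accessors, assumed to match the helpers used by the operations
    int getHeight() { return height; }
    int getWidth() { return width; }
    double getElement(int row, int col) { return matrix[row][col]; }

    public static void main(String[] args) {
        double[][] src = {{1, 2, 3}, {4, 5, 6}};
        SafeMatrix m = new SafeMatrix(src);
        src[0][0] = 99;                                             // does not affect m
        System.out.println(m.getElement(0, 0));                     // 1.0
        System.out.println(m.getHeight() + " by " + m.getWidth());  // 2 by 3
    }
}
```

Copying costs a little extra memory, but it means a Matrix can never be changed behind its back.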

This is the start of our simple Matrix class. Next we can add the fundamental matrix operations: matrix multiplication, scalar multiplication, matrix addition, and matrix subtraction.

Addition, subtraction, and scalar multiplication are rather simple. For addition and subtraction, you add or subtract the elements in the same row and column of the two matrices, so both matrices must have the same dimensions. For scalar multiplication, you multiply every element of the matrix by the scalar. The code below also includes scalar division, which divides every element by the scalar.
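As a small worked example, using two 2 by 2 matrices chosen purely for illustration:

```latex
A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}, \qquad
B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}

A + B = \begin{bmatrix} 6 & 8 \\ 10 & 12 \end{bmatrix}, \qquad
A - B = \begin{bmatrix} -4 & -4 \\ -4 & -4 \end{bmatrix}, \qquad
2A = \begin{bmatrix} 2 & 4 \\ 6 & 8 \end{bmatrix}
```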

    // Returns a new matrix with every element multiplied by scalar
    public Matrix scalarMultiplication(double scalar) {
        double[][] arr = new double[this.getHeight()][this.getWidth()];
        for (int i = 0; i < arr.length; i++) {
            for (int j = 0; j < arr[i].length; j++) {
                arr[i][j] = this.getElement(i, j) * scalar;
            }
        }
        return new Matrix(arr);
    }

    // Returns a new matrix with every element divided by scalar
    // (dividing by 0 gives Infinity or NaN, since these are doubles)
    public Matrix scalarDivision(double scalar) {
        double[][] arr = new double[this.getHeight()][this.getWidth()];
        for (int i = 0; i < arr.length; i++) {
            for (int j = 0; j < arr[i].length; j++) {
                arr[i][j] = this.getElement(i, j) / scalar;
            }
        }
        return new Matrix(arr);
    }
    // Returns this + b; both matrices must have the same dimensions
    public Matrix addMatrix(Matrix b) {
        if (b.getWidth() != this.getWidth() || b.getHeight() != this.getHeight()) {
            throw new IllegalArgumentException("Matrices are not the same dimensions");
        }
        double[][] arr = new double[this.getHeight()][this.getWidth()];
        for (int i = 0; i < arr.length; i++) {
            for (int j = 0; j < arr[i].length; j++) {
                arr[i][j] = this.getElement(i, j) + b.getElement(i, j);
            }
        }
        return new Matrix(arr);
    }

    // Returns this - b; both matrices must have the same dimensions
    public Matrix subtractMatrix(Matrix b) {
        if (b.getWidth() != this.getWidth() || b.getHeight() != this.getHeight()) {
            throw new IllegalArgumentException("Matrices are not the same dimensions");
        }
        double[][] arr = new double[this.getHeight()][this.getWidth()];
        for (int i = 0; i < arr.length; i++) {
            for (int j = 0; j < arr[i].length; j++) {
                arr[i][j] = this.getElement(i, j) - b.getElement(i, j);
            }
        }
        return new Matrix(arr);
    }
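Addition and subtraction differ only in the operator applied to each pair of elements. One possible refactoring, sketched here on plain `double[][]` arrays rather than the Matrix class, passes that operator in as a `java.util.function.DoubleBinaryOperator`. The `ElementWise` class and its method names are hypothetical and not part of the original code.

```java
import java.util.Arrays;
import java.util.function.DoubleBinaryOperator;

// Hypothetical helper, not part of the original Matrix class.
class ElementWise {
    // Applies op to each pair of corresponding elements (assumes rectangular arrays)
    static double[][] combine(double[][] a, double[][] b, DoubleBinaryOperator op) {
        if (a.length != b.length || a[0].length != b[0].length) {
            throw new IllegalArgumentException("Matrices are not the same dimensions");
        }
        double[][] out = new double[a.length][a[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[i].length; j++) {
                out[i][j] = op.applyAsDouble(a[i][j], b[i][j]);
            }
        }
        return out;
    }

    static double[][] add(double[][] a, double[][] b) {
        return combine(a, b, (x, y) -> x + y);
    }

    static double[][] subtract(double[][] a, double[][] b) {
        return combine(a, b, (x, y) -> x - y);
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(add(a, b)));      // [[6.0, 8.0], [10.0, 12.0]]
        System.out.println(Arrays.deepToString(subtract(a, b))); // [[-4.0, -4.0], [-4.0, -4.0]]
    }
}
```

The dimension check and the loops then live in one place, so a fix to one operation automatically applies to the other.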

To see the whole code, use the GitHub link at the end of this post, which has the entire source code. Helper methods like getElement are simple enough to understand: getElement(row, col) just returns the element at the given row and column.

Multiplying matrices is a bit trickier, and it is not very efficient either: this straightforward method takes roughly O(n³) time for two n by n matrices. To compute the element in row i and column j of the result, take row i of the first matrix and column j of the second matrix, multiply their first elements together, then their second elements, and so on, and sum all of those products. Doing this for every row of the first matrix paired with every column of the second fills in the whole result. The following is the implementation of matrix multiplication:

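    // Returns this * b; requires this.getWidth() == b.getHeight()
    public Matrix multiplyMatrix(Matrix b) {
        if (this.getWidth() != b.getHeight()) {
            throw new IllegalArgumentException("Width of the first matrix must equal height of the second");
        }
        double[][] arr = new double[this.getHeight()][b.getWidth()];
        for (int row = 0; row < this.getHeight(); row++) {
            for (int col = 0; col < b.getWidth(); col++) {
                // Element (row, col) is row 'row' of this times column 'col' of b
                arr[row][col] = multiplyCells(this, b, row, col);
            }
        }
        return new Matrix(arr);
    }

    // Sum of the products of the given row of a with the given column of b
    public double multiplyCells(Matrix a, Matrix b, int row, int col) {
        double val = 0;
        for (int i = 0; i < a.getWidth(); i++) {
            val += a.getElement(row, i) * b.getElement(i, col);
        }
        return val;
    }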
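The loop in multiplyCells computes one entry of the product. Written as a formula, for an m by n matrix A and an n by p matrix B, each entry of C = AB is the sum below. There are m·p entries and each one needs n multiplications, so the total work is m·n·p, which is n³ for two n by n matrices. A small worked example with 2 by 2 matrices chosen for illustration:

```latex
C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}, \qquad 1 \le i \le m,\; 1 \le j \le p

\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}
\begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}
=
\begin{bmatrix} 1 \cdot 5 + 2 \cdot 7 & 1 \cdot 6 + 2 \cdot 8 \\ 3 \cdot 5 + 4 \cdot 7 & 3 \cdot 6 + 4 \cdot 8 \end{bmatrix}
=
\begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix}
```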

It’s also worth noting that if the first matrix has dimensions a by b and the second has dimensions c by d, then b must equal c for matrix multiplication to work, and the result will have dimensions a by d.
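To see the dimension rule in action, here is a standalone sketch of the same row-times-column algorithm on plain `double[][]` arrays. The `MatMul` class is hypothetical, not from the original repository. It throws an IllegalArgumentException when b does not equal c, and it shows that multiplying a 2 by 3 matrix by a 3 by 2 matrix gives a 2 by 2 result, while multiplying them in the opposite order gives a 3 by 3 result, so the order of the operands matters.

```java
import java.util.Arrays;

// Hypothetical standalone version of multiplyMatrix/multiplyCells on plain arrays.
class MatMul {
    // Multiplies an (a x b) array by a (c x d) array; requires b == c, result is (a x d)
    static double[][] multiply(double[][] left, double[][] right) {
        int a = left.length, b = left[0].length;
        int c = right.length, d = right[0].length;
        if (b != c) {
            throw new IllegalArgumentException(
                "Cannot multiply " + a + " x " + b + " by " + c + " x " + d);
        }
        double[][] out = new double[a][d];
        for (int row = 0; row < a; row++) {
            for (int col = 0; col < d; col++) {
                double sum = 0;
                // Sum of products of row 'row' of left with column 'col' of right
                for (int k = 0; k < b; k++) {
                    sum += left[row][k] * right[k][col];
                }
                out[row][col] = sum;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2, 3}, {4, 5, 6}};       // 2 x 3
        double[][] y = {{7, 8}, {9, 10}, {11, 12}};  // 3 x 2
        double[][] xy = multiply(x, y);              // 2 x 2
        double[][] yx = multiply(y, x);              // 3 x 3
        System.out.println(Arrays.deepToString(xy)); // [[58.0, 64.0], [139.0, 154.0]]
        System.out.println(yx.length + " x " + yx[0].length); // 3 x 3
    }
}
```

Calling `multiply(x, x)` would throw, because a 2 by 3 matrix cannot be multiplied by another 2 by 3 matrix.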

The last matrix operation for now is the transpose, which turns the first row of a matrix into its first column, the second row into its second column, and so on. Rather than rotating the matrix 90 degrees, this flips it across its main diagonal (the entries running from the top left toward the bottom right), so a matrix with dimensions n by m has dimensions m by n once transposed. Here is the implementation.

    // Returns a new matrix whose rows are the columns of this matrix
    public Matrix transpose() {
        double[][] arr = new double[this.getWidth()][this.getHeight()];
        for (int row = 0; row < this.getHeight(); row++) {
            for (int col = 0; col < this.getWidth(); col++) {
                arr[col][row] = this.getElement(row, col);
            }
        }
        return new Matrix(arr);
    }
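In index notation, the element at row i and column j of the transpose is the element at row j and column i of the original. A quick way to sanity-check an implementation is that transposing twice gives back the original matrix. The sketch below is a standalone version on plain `double[][]` arrays; the `Transpose` class name is hypothetical.

```java
import java.util.Arrays;

// Hypothetical standalone version of transpose() on plain arrays.
class Transpose {
    static double[][] transpose(double[][] m) {
        double[][] out = new double[m[0].length][m.length];
        for (int row = 0; row < m.length; row++) {
            for (int col = 0; col < m[0].length; col++) {
                out[col][row] = m[row][col]; // row 'row' becomes column 'row'
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2, 3}, {4, 5, 6}};                  // 2 x 3
        double[][] t = transpose(a);                            // 3 x 2
        System.out.println(Arrays.deepToString(t));             // [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
        System.out.println(Arrays.deepEquals(transpose(t), a)); // true
    }
}
```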

There are a few more matrix operations that will eventually be needed, but these are the fundamental ones that capture what matrices really do. In later blogs I will discuss the implementation of those additional methods and what they can be used for. Use this GitHub link to see the whole code in context: https://github.com/arnavkartikeya/MachineLearningJava

Arnav Kartikeya

A high school student interested in cognitive science and programming