Bayesian Learning Agents#

JAX is an extensible system for composable function transformations. You can read about its groundbreaking benefits for a wide range of scientific applications here.

Many practitioners prefer functional APIs over object-oriented APIs, especially when new concepts are being developed (e.g. research) or learned (e.g. education). At a minimum you get a substantial performance improvement, together with functionality that any differentiable learner needs (such as automatic differentiation).

This assignment aims to teach you a few things about Bayesian Linear Regression and to let you develop the method from scratch using … JAX. Please go over sections 20.2.6-20.2.8 of this textbook to understand the principles behind Bayesian regression.

**Make sure you comply with the JAX API requirement. Make sure you write markdown cells or inline comments describing every line. The notebook must read like a tutorial.**

**For this assignment you will implement everything in JAX from scratch, without any other APIs, but you are free to look at implementations of Bayesian regression such as the notebooks given in your course notes.**

Generate and transform data (20 points)#

Generate a sinusoidal dataset of \(m\) data points such as the one we have met in class. The synthetic dataset is generated by the function \(\sin(2 \pi x) + \epsilon\), where \(x\) is a uniformly distributed random variable and \(\epsilon \sim N(\mu=0.0, \sigma^2=0.3)\). Write the code such that you can produce results for the very small (e.g. 3 data points), low-\(m\) (e.g. 10 data points) and larger-\(m\) (e.g. 100 data points) cases.
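A minimal sketch of such a generator in JAX might look like the following; the function name `make_dataset` and the default seed are illustrative choices, and note that a variance of \(\sigma^2 = 0.3\) means the noise standard deviation is \(\sqrt{0.3}\):

```python
import jax
import jax.numpy as jnp

def make_dataset(key, m, noise_std=jnp.sqrt(0.3)):
    """Generate m points from sin(2*pi*x) + eps, with x ~ U(0, 1)
    and eps ~ N(0, sigma^2 = 0.3)."""
    kx, keps = jax.random.split(key)          # independent keys for x and noise
    x = jax.random.uniform(kx, (m,))          # uniformly distributed inputs
    eps = noise_std * jax.random.normal(keps, (m,))
    y = jnp.sin(2.0 * jnp.pi * x) + eps
    return x, y

# The three regimes asked for: very small, low m, larger m
x, y = make_dataset(jax.random.PRNGKey(0), m=100)
```

Because JAX uses explicit PRNG keys, splitting the key for `x` and `eps` keeps the two random draws independent and the whole dataset reproducible.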

Simulate an online learning (streaming) system where data from this dataset arrive sequentially, one example at a time. In industry, one of the most widely used publish-subscribe systems is Kafka (and there is a simpler compatible system called Redpanda). Although in practice such a system would be used, here you can simply use a for loop to read the samples sequentially.
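One way to mimic such a feed, assuming a plain Python generator as the stand-in for a real broker (the name `stream` is illustrative):

```python
import jax.numpy as jnp

def stream(x, y):
    """Yield (x_i, y_i) pairs one at a time, mimicking a Kafka-style feed."""
    for xi, yi in zip(x, y):
        yield xi, yi

# Toy data standing in for the generated dataset
xs = jnp.array([0.1, 0.5, 0.9])
ys = jnp.sin(2.0 * jnp.pi * xs)

received = []
for xi, yi in stream(xs, ys):        # one example "arrives" per iteration
    received.append((float(xi), float(yi)))
```

The consumer never sees the whole dataset at once, which is the property your online learners below must respect.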

Use Gaussian Radial Basis Function (RBF) features (as compared to the polynomials we have seen in class) and create the hypothesis set. See what Gaussian RBFs look like here:

# Your code goes here
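As a starting point, a Gaussian RBF feature map could be sketched as below; the fixed width and the evenly spaced centres are illustrative hyperparameter choices you may want to vary:

```python
import jax.numpy as jnp

def rbf_features(x, centers, width=0.1):
    """Design matrix Phi with Phi[i, j] = exp(-(x_i - mu_j)^2 / (2*width^2)),
    plus a leading bias column of ones."""
    d = x[:, None] - centers[None, :]                    # pairwise distances
    phi = jnp.exp(-0.5 * (d / width) ** 2)
    return jnp.concatenate([jnp.ones((x.shape[0], 1)), phi], axis=1)

centers = jnp.linspace(0.0, 1.0, 9)                      # 9 centres on [0, 1]
Phi = rbf_features(jnp.linspace(0.0, 1.0, 100), centers)
# Phi has shape (100, 10): one bias column + 9 RBF columns
```

The number of centres controls model complexity, which is exactly the knob you will sweep in the error-plot section.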

Online Maximum Likelihood - Linear Regression (30 points)#

Implement an online maximum likelihood method for linear regression using SGD.
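For a Gaussian noise model, maximizing the likelihood is equivalent to minimizing the squared error, so an online update can be sketched with `jax.grad` as follows (a minimal single-feature toy; the learning rate and step count are illustrative):

```python
import jax
import jax.numpy as jnp

def sq_loss(w, phi, y):
    """Per-example squared error (ML objective under Gaussian noise)."""
    return (jnp.dot(phi, w) - y) ** 2

grad_loss = jax.jit(jax.grad(sq_loss))   # gradient w.r.t. the weights w

def sgd_step(w, phi, y, lr=0.1):
    """One online SGD update on a single (phi, y) example."""
    return w - lr * grad_loss(w, phi, y)

# Toy stream: learn y = 2*x with a single feature
w = jnp.zeros(1)
key = jax.random.PRNGKey(1)
for _ in range(500):
    key, sub = jax.random.split(key)
    x = jax.random.uniform(sub, (1,))
    w = sgd_step(w, x, 2.0 * x[0])       # w should approach 2.0
```

In the assignment, `phi` would be the RBF feature vector of each arriving sample rather than the raw input.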

Online Bayesian Linear Regression (30 points)#

Write the Bayesian linear regression model and plot the predictive distribution over the data as examples are received. Write all the math equations associated with such figures and explain the functional form of the posterior distribution.
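With a Gaussian prior \(w \sim N(m_0, S_0)\) and noise precision \(\beta\), the conjugate posterior after each example stays Gaussian, so the online update can be sketched as below; the prior precision `alpha` is an illustrative hyperparameter choice:

```python
import jax.numpy as jnp

beta = 1.0 / 0.3            # noise precision (sigma^2 = 0.3 from the data model)
alpha = 2.0                 # prior precision, an illustrative choice

def bayes_update(m, S_inv, phi, y, beta):
    """One-sample conjugate update for w ~ N(m, S):
       S_new^{-1} = S^{-1} + beta * phi phi^T
       m_new      = S_new (S^{-1} m + beta * y * phi)
    """
    S_inv_new = S_inv + beta * jnp.outer(phi, phi)
    m_new = jnp.linalg.inv(S_inv_new) @ (S_inv @ m + beta * y * phi)
    return m_new, S_inv_new

def predictive(m, S_inv, phi, beta):
    """Predictive distribution at phi: mean phi^T m,
    variance 1/beta + phi^T S phi (noise + parameter uncertainty)."""
    S = jnp.linalg.inv(S_inv)
    return phi @ m, 1.0 / beta + phi @ S @ phi
```

Tracking `(m, S_inv)` after every arriving sample gives you exactly the sequence of predictive distributions the question asks you to plot; the variance term shows how parameter uncertainty shrinks as \(m\) grows.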

Find a way to show whether the Bayesian approach is better in terms of predictions than the maximum likelihood approach obtained earlier, and for what data regime (\(m\)).

# Your code goes here

Error Plots (20 points)#

  1. Divide the dataset into train and test.

  2. Obtain plots of the training and test MSE vs the model complexity, assuming that in Bayesian Linear Regression we use the mean of the predictive distribution obtained in the previous question as the regression function.
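The two steps above can be sketched as follows, using the number of RBF centres as the complexity axis; the 70/30 split, the widths, and the swept centre counts are illustrative assumptions, and here an ordinary least-squares fit stands in for whichever regression function you use:

```python
import jax
import jax.numpy as jnp

def rbf(x, centers, width=0.1):
    """Gaussian RBF design matrix: exp(-(x_i - mu_j)^2 / (2*width^2))."""
    d = x[:, None] - centers[None, :]
    return jnp.exp(-0.5 * (d / width) ** 2)

def fit_and_score(x_tr, y_tr, x_te, y_te, n_centers):
    """Least-squares fit with n_centers RBFs; return (train MSE, test MSE)."""
    c = jnp.linspace(0.0, 1.0, n_centers)
    Phi_tr, Phi_te = rbf(x_tr, c), rbf(x_te, c)
    w, *_ = jnp.linalg.lstsq(Phi_tr, y_tr)
    mse = lambda Phi, y: float(jnp.mean((Phi @ w - y) ** 2))
    return mse(Phi_tr, y_tr), mse(Phi_te, y_te)

# Synthetic data and a simple 70/30 train/test split
kx, ke = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(kx, (100,))
y = jnp.sin(2.0 * jnp.pi * x) + jnp.sqrt(0.3) * jax.random.normal(ke, (100,))
x_tr, y_tr, x_te, y_te = x[:70], y[:70], x[70:], y[70:]

# Sweep model complexity: (train MSE, test MSE) per number of centres
scores = {k: fit_and_score(x_tr, y_tr, x_te, y_te, k) for k in (2, 5, 9, 15)}
```

Plotting the two curves in `scores` against the number of centres gives the classic picture: training error falls monotonically with complexity while test error eventually turns up.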

# Your code goes here