Accelerated Functional Programming: The Case for JAX
Description
JAX is an extensible system for composable function transformations. You can read about its groundbreaking benefits for a wide range of scientific applications here.
You are creating a 3-hour tutorial (when presented at a typical educational pace) that will teach the rest of the company about this new technology and persuade them that JAX is the best thing since sliced bread.
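To make "composable function transformations" concrete, here is a minimal sketch using only core JAX transformations. The function f and the input values are illustrative choices. grad differentiates a plain Python function, vmap vectorizes that derivative over a batch, and jit compiles the composed result.

```python
import jax
import jax.numpy as jnp

def f(x):
    """A plain, pure Python function of a scalar input."""
    return jnp.sin(x) ** 2

# grad(f) computes df/dx for one scalar; vmap maps it over a batch;
# jit compiles the composed function with XLA.
df_batched = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(0.0, 1.0, 5)
print(df_batched(xs))  # matches the analytic derivative sin(2x)
```

Each transformation takes a function and returns a new function, which is why they can be stacked in any order that makes sense for the problem.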
Step 1: Select and name your company
Select a name that best represents the domain you want to develop products and solutions for, for example Optima Labs Inc. Write a one-sentence summary of your pitch, for example: "Optima Labs automates the tuning of your models so you don't have to."
The domain would usually be a well-established one, but you are free to choose one that is not. Examples of well-established ML/AI domains include computer vision and video analytics, natural language processing, etc. Ask the professor if you are uncertain or the team can't make a quick decision.
Step 2: Select a Kaggle competition that is compatible with the scope of your company
You will find the ongoing competitions here. Read the competition description and store the data in your Google Drive.
Step 3: Select a learning method
You can choose from the sections/chapters of your course site, for example Ensemble Learning / Gradient Boosting. Feel free to select methods outside the scope of the course.
The learning method you select will be the basis of your tutorial. You will use JAX both to improve your audience's understanding of this topic and to solve the problem posed by the Kaggle competition.
For example, the team at Optima Labs selected the Ensemble Learning chapter. They found a library called GPJax to demonstrate hyperparameter optimization of gradient boosting methods (XGBoost), and applied it to a Kaggle competition that could be solved with gradient boosting.
As another example, a company named FakeNews Killer Inc selected the DNN chapter. They found WikiGraphs, which powers one of the components of their fake news detection pipeline, and applied it to the fake news competition.
You can find plenty of JAX-based libraries to use here, but you should also search GitHub, as this list is not exhaustive.
Step 4: Submit your proposal
Read the project guidelines to make sure you cover all the main points. Your proposal must be submitted as a GitHub repo. Add the professor and the TA as collaborators.
Step 5: Write the code and the report
You need to demonstrate functional programming powered by the JAX ecosystem, and the report must be written as a tutorial in Colab notebook format. See also the general project guidelines.
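As a starting point for the "functional programming" part of the tutorial, here is a minimal sketch of the style JAX encourages: pure functions with explicit inputs, parameters that are replaced rather than mutated, explicit PRNG keys instead of global random state, and immutable arrays. The linear-regression task, learning rate, and data are illustrative assumptions, not part of the assignment.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # learning rate (illustrative choice)

def loss(params, x, y):
    """Pure mean-squared-error loss for a linear model; no hidden state."""
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    """Pure update: returns new parameters instead of mutating the old ones."""
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

# Randomness is explicit: a key is created and split, never drawn from global state.
key = jax.random.PRNGKey(0)
key_x, key_noise = jax.random.split(key)
x = jax.random.normal(key_x, (100, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.1 * jax.random.normal(key_noise, (100,))

params = (jnp.zeros(3), jnp.array(0.0))
for _ in range(200):
    params = sgd_step(params, x, y)

# Arrays are immutable: .at[...].set(...) returns a new array, w0 is unchanged.
w0 = jnp.zeros(3)
w1 = w0.at[0].set(1.0)
```

Because sgd_step depends only on its arguments, jit can compile it once and reuse it, and the same function could be passed to other transformations (for example vmap over several datasets) without changes.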
Invitation: We are particularly interested in receiving proposals that combine simulation and learning. The JAX ecosystem offers accelerated and greatly simplified tools for this, for example for computational fluid dynamics (CFD): https://github.com/google/jax-cfd. See the Heartflow problem statement in lecture 1.
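One way to show what "combining simulation and learning" means is to differentiate straight through a simulator. The sketch below does not use jax-cfd; it assumes a toy decay equation dx/dt = -k x, integrated with explicit Euler steps via jax.lax.scan, as a stand-in for a real simulator. It then recovers the rate k from an observed trajectory by gradient descent. DT, N_STEPS, LR, and K_TRUE are illustrative values.

```python
import jax
import jax.numpy as jnp

DT = 0.01      # time step (illustrative)
N_STEPS = 100  # number of Euler steps, i.e. simulate t in (0, 1]
LR = 2.0       # gradient-descent step size (illustrative)

def simulate(k):
    """Explicit-Euler integration of the toy ODE dx/dt = -k * x from x(0) = 1."""
    def step(x, _):
        x_next = x - DT * k * x
        return x_next, x_next  # (new carry, value recorded in the trajectory)
    x0 = jnp.asarray(1.0, dtype=jnp.float32)
    _, traj = jax.lax.scan(step, x0, xs=None, length=N_STEPS)
    return traj

# "Observed" data generated by the same simulator with a known rate.
K_TRUE = 2.0
target = simulate(jnp.float32(K_TRUE))

def loss(k):
    """Mismatch between the simulated and observed trajectories."""
    return jnp.mean((simulate(k) - target) ** 2)

# Differentiate through the whole simulation loop, then compile.
grad_loss = jax.jit(jax.grad(loss))

k = jnp.float32(0.5)  # initial guess for the decay rate
for _ in range(500):
    k = k - LR * grad_loss(k)
print(k)  # should approach K_TRUE
```

The same pattern (a simulator written as pure JAX code, a loss on its output, and jax.grad through the time loop) is what makes differentiable simulators such as jax-cfd usable inside learning pipelines.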