Updated On : Jan-24,2022 Time Investment : ~45 mins

Recommender Systems (Collaborative Filtering) using Flax (JAX)

Recommender systems predict the rating that a particular user would give to a particular item. They are commonly used by online platforms such as Amazon, Netflix, and Hotstar. These systems try to learn each user's taste and then recommend items (movies, books, products, etc.) based on it. Collaborative filtering is one way of implementing a recommender system. In collaborative filtering, we represent each user and each item with a vector of floats of a pre-decided size, and every user and every item gets its own vector. We then train a neural network to update these vectors so that the dot product of a user's vector and an item's vector (element-wise multiplication followed by a sum) comes close to the actual rating that the user gave to that item. By doing this for all known ratings, the vectors learn user tastes as well as item attributes. Once the vectors have been trained, we can use them to predict ratings for items a user has not rated yet and recommend the items with the highest predicted ratings. Because the vectors of many users and items are updated together during training, users with similar tastes generally end up being recommended similar items.
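To make the idea concrete, here is a minimal NumPy sketch, assuming made-up 4-dimensional vectors (the tutorial itself uses 50 factors), of how a predicted rating is the dot product of a user vector and an item vector.

```python
import numpy as np

# Hypothetical 4-factor vectors; real ones are learned during training
user_vec = np.array([1.0, 0.5, 2.0, 0.0])
item_vec = np.array([2.0, 1.0, 1.5, 3.0])

# Element-wise multiply, then sum: this is the dot product
predicted_rating = (user_vec * item_vec).sum()
print(predicted_rating)  # 5.5

# np.dot computes the same value
assert np.isclose(predicted_rating, np.dot(user_vec, item_vec))
```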

As a part of this tutorial, we'll explain how to implement collaborative filtering using Flax, a deep learning framework for Python built on top of JAX. Deep learning frameworks generally refer to the vectors discussed above as embeddings and provide embedding layers that map user/item ids to their vectors. We expect the reader to have a background in Flax/JAX. Readers who want to refresh their Flax/JAX knowledge may want to review an introductory Flax/JAX tutorial first, as it will help with this one.

We'll be using the Japanese Anime dataset available on Kaggle in this tutorial. Please feel free to download it to follow along.


Below, we have imported the necessary libraries and printed the versions used in this tutorial.

import flax

print("FLAX Version : {}".format(flax.__version__))
FLAX Version : 0.3.6
import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.2.26
import optax

print("OPTAX Version : {}".format(optax.__version__))
OPTAX Version : 0.1.0

Load and Prepare Data

In this section, we'll load the anime ratings dataset, which we downloaded from Kaggle, along with its related anime information, and prepare it for the collaborative filtering task.

The dataset has 2 CSV files.

  • rating.csv - Ratings given by users to anime.
  • anime.csv - Information about each anime, such as its name, genre, average rating, etc.

Load Ratings Data

Below, we have loaded the rating data from the rating.csv file. It has the user id, the anime id, and the rating given by the user to that anime. Some entries correspond to a user who watched an anime without rating it, and those entries have a rating of -1. We have removed them because we want only entries where the user has actually given a rating.

import pandas as pd
import numpy as np
import gc

ratings_df = pd.read_csv("/kaggle/input/anime-recommendations-database/rating.csv")

ratings_df = ratings_df[ratings_df.rating!=-1]

print("Data : {}".format(ratings_df.shape))

ratings_df.head()
Data : (6337241, 3)
user_id anime_id rating
47 1 8074 10
81 1 11617 10
83 1 11757 10
101 1 15451 10
153 2 11771 10

Load Anime Data

Below, we have loaded the anime data from the anime.csv file. It maps each anime id to the anime's name, genre, type, number of episodes, average rating, and number of members. We have printed the first few rows to give a glimpse of the dataset.

anime_df = pd.read_csv("/kaggle/input/anime-recommendations-database/anime.csv")

print("Anime : {}".format(anime_df.shape))

anime_df.head()
Anime : (12294, 7)
anime_id name genre type episodes rating members
0 32281 Kimi no Na wa. Drama, Romance, School, Supernatural Movie 1 9.37 200630
1 5114 Fullmetal Alchemist: Brotherhood Action, Adventure, Drama, Fantasy, Magic, Mili... TV 64 9.26 793665
2 28977 Gintama° Action, Comedy, Historical, Parody, Samurai, S... TV 51 9.25 114262
3 9253 Steins;Gate Sci-Fi, Thriller TV 24 9.17 673572
4 9969 Gintama' Action, Comedy, Historical, Parody, Samurai, S... TV 51 9.16 151266

Merge Ratings and Anime Data

In this section, we have merged the rating dataset with the anime dataset using the pandas merge() method on the anime_id column, which is common to both dataframes. After merging, the dataset has two rating columns (pandas names them rating_x and rating_y), so we have renamed the one from the rating dataframe to user_rating and the one from the anime dataframe to average_anime_rating.

ratings_df = ratings_df.merge(anime_df, on="anime_id")

ratings_df = ratings_df.rename(columns={"rating_x":"user_rating", "rating_y": "average_anime_rating"})

ratings_df.head()
user_id anime_id user_rating name genre type episodes average_anime_rating members
0 1 8074 10 Highschool of the Dead Action, Ecchi, Horror, Supernatural TV 12 7.46 535892
1 3 8074 6 Highschool of the Dead Action, Ecchi, Horror, Supernatural TV 12 7.46 535892
2 5 8074 2 Highschool of the Dead Action, Ecchi, Horror, Supernatural TV 12 7.46 535892
3 12 8074 6 Highschool of the Dead Action, Ecchi, Horror, Supernatural TV 12 7.46 535892
4 14 8074 6 Highschool of the Dead Action, Ecchi, Horror, Supernatural TV 12 7.46 535892
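The rating_x/rating_y names come from pandas' default merge suffixes, which are applied when both frames share a non-key column. A minimal sketch on two tiny made-up frames (the values are hypothetical) shows where the names come from; passing merge's suffixes parameter is an alternative to renaming afterwards.

```python
import pandas as pd

# Tiny stand-ins for rating.csv and anime.csv (hypothetical values)
ratings = pd.DataFrame({"user_id": [1, 2], "anime_id": [10, 10], "rating": [8, 6]})
anime = pd.DataFrame({"anime_id": [10], "name": ["Some Anime"], "rating": [7.5]})

merged = ratings.merge(anime, on="anime_id")
# Both frames have a "rating" column, so pandas appends the default suffixes
print(list(merged.columns))  # ['user_id', 'anime_id', 'rating_x', 'name', 'rating_y']

# Same renaming as in the tutorial
merged = merged.rename(columns={"rating_x": "user_rating", "rating_y": "average_anime_rating"})
```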

Few Sanity Checks

In this section, we count the unique anime ids (items) and user ids. We also check whether any null values are present in the user id, anime id, or rating columns.

n_items = len(ratings_df['anime_id'].unique())
n_users = len(ratings_df['user_id'].unique())
ratings = ratings_df["user_rating"].unique()

print("Unique Items : {}".format(n_items))
print("Unique Users : {}".format(n_users))
print("Ratings : {}".format(ratings))
Unique Items : 9926
Unique Users : 69600
Ratings : [10  6  2  7  9  8  4  5  3  1]
## .any() returns True if at least one value in the column is missing
print("Are there Null User IDs ?  : {}".format(ratings_df["user_id"].isnull().any()))
print("Are there Null Anime IDs ? : {}".format(ratings_df["anime_id"].isnull().any()))
print("Are there Null Ratings ?   : {}".format(ratings_df["user_rating"].isnull().any()))
Are there Null User IDs ?  : False
Are there Null Anime IDs ? : False
Are there Null Ratings ?   : False

Separate Users With Only One Review and Remove Them From Data

In this section, we loop through all unique user ids and separate users who have rated only one anime from those who have rated more than one. We then filter our dataframe to keep only users who have rated more than one anime. The reason is that we'll split the dataset into train and test sets per user id. A user with only one rating would have to go into the train set, with no entry in the test set. After removing such users, every user id is present in both the train and test sets. This is a design choice: readers who prefer to keep users with a single rating can do so, and the split function below handles that case.

from collections import Counter

less_review_users, more_review_users = [], []
for user_id, cnt in Counter(ratings_df["user_id"].values).items():
    if cnt == 1:
        less_review_users.append(user_id)
    else:
        more_review_users.append(user_id)

print("Users with only single review : {}".format(len(less_review_users)))
print("Users with multiple reviews   : {}".format(len(more_review_users)))

ratings_df = ratings_df[ratings_df["user_id"].isin(more_review_users)]
Users with only single review : 3249
Users with multiple reviews   : 66351
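The Counter loop above can also be written with pandas' value_counts(). A minimal sketch on a tiny made-up dataframe (values hypothetical); note that here more_review_users is a pandas Index rather than a Python list, which isin() accepts just the same.

```python
import pandas as pd

# Hypothetical ratings: user 3 has only one rating
ratings_df = pd.DataFrame({"user_id": [1, 1, 2, 2, 2, 3],
                           "anime_id": [10, 11, 10, 12, 13, 11],
                           "user_rating": [8, 7, 9, 6, 5, 10]})

counts = ratings_df["user_id"].value_counts()      # number of ratings per user
more_review_users = counts[counts > 1].index       # users with 2+ ratings
ratings_df = ratings_df[ratings_df["user_id"].isin(more_review_users)]

print(sorted(ratings_df["user_id"].unique().tolist()))  # [1, 2]
```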

Shuffle Merged (Rating + Anime) Data

In this section, we shuffle the dataframe using the pandas DataFrame.sample() method with frac=1.0. Before shuffling, we also create dictionaries that map anime ids to indices and to their titles, as they'll be useful later. We then check how many user ids, anime ids, and unique rating values remain after removing users with a single review.

unique_items = ratings_df["anime_id"].unique()
item_to_idx = dict(zip(unique_items,range(len(unique_items))))
item_to_title = dict(list(zip(ratings_df["anime_id"].values, ratings_df["name"].values)))

#ratings_df = ratings_df.reset_index()
ratings_df = ratings_df.sample(frac=1.0,random_state=123)

ratings_df.head()
user_id anime_id user_rating name genre type episodes average_anime_rating members
6299460 49930 3396 5 Gloria: Kindan no Ketsuzoku Drama, Hentai, Mystery, Romance OVA 3 5.61 654
5225598 71792 31229 6 Servamp Action, Comedy, Drama, Josei, Supernatural, Va... TV 12 7.12 73126
1689446 24635 13403 7 Inu x Boku SS Special Comedy, Shounen, Supernatural Special 1 7.79 46803
5266404 17026 3269 5 .hack//G.U. Trilogy Action, Fantasy, Game, Sci-Fi Movie 1 7.32 22537
6159762 65836 1085 4 Interlude Adventure, Horror, Mystery OVA 3 6.67 7573
n_items = len(ratings_df['anime_id'].unique())
n_users = len(ratings_df['user_id'].unique())
ratings = ratings_df["user_rating"].unique()

print("Unique Items : {}".format(n_items))
print("Unique Users : {}".format(n_users))
print("Ratings : {}".format(ratings))
Unique Items : 9926
Unique Users : 66351
Ratings : [ 5  6  7  4  9 10  8  3  1  2]

Pivot Merged Data to Analyze User-Rating Relation

In this section, we have pivoted the first 10 entries of our dataframe into a table of ratings given by users to particular anime.

pivoted_df = pd.pivot(ratings_df.head(10), index="user_id", columns="name", values="user_rating")

pivoted_df.head(10)
name .hack//G.U. Trilogy Bleach Movie 1: Memories of Nobody Clannad Dragon Ball Z Gloria: Kindan no Ketsuzoku Interlude Inu x Boku SS Special Kotoura-san Kuroshitsuji Servamp
user_id
6931 NaN NaN NaN NaN NaN NaN NaN 5.0 NaN NaN
17026 5.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN
24635 NaN NaN NaN NaN NaN NaN 7.0 NaN NaN NaN
36815 NaN NaN NaN 9.0 NaN NaN NaN NaN NaN NaN
40945 NaN 10.0 NaN NaN NaN NaN NaN NaN NaN NaN
44941 NaN NaN NaN NaN NaN NaN NaN NaN 10.0 NaN
49930 NaN NaN NaN NaN 5.0 NaN NaN NaN NaN NaN
65836 NaN NaN NaN NaN NaN 4.0 NaN NaN NaN NaN
70449 NaN NaN 9.0 NaN NaN NaN NaN NaN NaN NaN
71792 NaN NaN NaN NaN NaN NaN NaN NaN NaN 6.0

Please make a NOTE that we call gc.collect() a few times in this tutorial. It triggers Python's garbage collector, which frees memory held by unreachable objects.

gc.collect()
63

Split Data into Train/Test Sets Based on User IDs

In this section, we have created a small function that divides the dataset into train and test sets. Here, the input features are the user id and anime id, and the target values are ratings. We'll feed the user id and anime id to our neural network, and it'll predict the rating given by that user to that anime.

To divide the dataset, we first take the unique user ids. Then we loop through each user id, take all entries of that user (all ratings they have given), and divide them into a train set (the first 90% of entries) and a test set (the remaining 10%). Since the dataframe was shuffled earlier, this is a random 90/10 split per user. The function also handles users with only one rating by putting that rating in the train set. We have already removed such users, but the code will still work if the reader decides to keep them.

Our feature arrays (X_train, X_test) have two columns (user id, anime id). The target arrays (Y_train, Y_test) hold the ratings.

def train_test_split(df):
    unique_users = np.unique(df["user_id"].values)
    X_train, X_test, Y_train, Y_test = [], [], [], []
    for i, user_id in enumerate(unique_users):
        ratings_temp = df[df["user_id"] == user_id]
        if ratings_temp.shape[0]==1: ## If only one sample per user then give it to train set
            X_train.extend(ratings_temp[["user_id","anime_id"]].values.tolist())
            Y_train.append(ratings_temp["user_rating"].values[0])
        else:
            idx = int(ratings_temp.shape[0]* 0.9) ## 90% train and 10% test
            ## Populate train data
            X_train.extend(ratings_temp[["user_id","anime_id"]].values[:idx].tolist())
            Y_train.extend(ratings_temp["user_rating"].values[:idx].tolist())
            ## Populate test data
            X_test.extend(ratings_temp[["user_id","anime_id"]].values[idx:].tolist())
            Y_test.extend(ratings_temp["user_rating"].values[idx:].tolist())

        if (i+1)%10000==0:
            print("{} users completed.".format(i+1))

    return np.array(X_train), np.array(X_test), np.array(Y_train), np.array(Y_test)

%time X_train, X_test, Y_train, Y_test = train_test_split(ratings_df)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
10000 users completed.
20000 users completed.
30000 users completed.
40000 users completed.
50000 users completed.
60000 users completed.
CPU times: user 17min 35s, sys: 12.9 s, total: 17min 48s
Wall time: 7min 54s
((5669769, 2), (664221, 2), (5669769,), (664221,))
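The loop above filters the whole dataframe once per user, which is why it takes almost 8 minutes of wall time. A possible vectorized alternative, sketched below under the assumption that the same int(n * 0.9) rule should apply, uses groupby().cumcount() to get each row's position within its user. The function name train_test_split_fast is hypothetical. One difference: rows keep the dataframe's (shuffled) order instead of being grouped by user.

```python
import numpy as np
import pandas as pd

def train_test_split_fast(df, train_frac=0.9):
    # Hypothetical vectorized alternative to the per-user loop (not the tutorial's code)
    pos = df.groupby("user_id").cumcount()                 # 0-based position of each row within its user
    size = df["user_id"].map(df["user_id"].value_counts()) # number of ratings of that user
    n_train = (size * train_frac).astype(int)              # same int(n * 0.9) rule as the loop
    n_train = n_train.where(size > 1, 1)                   # single-rating users go entirely to train
    is_train = (pos < n_train).to_numpy()

    X = df[["user_id", "anime_id"]].to_numpy()
    Y = df["user_rating"].to_numpy()
    return X[is_train], X[~is_train], Y[is_train], Y[~is_train]

# Tiny usage example: user 1 has 3 ratings, user 2 has 2, user 3 has 1
df = pd.DataFrame({"user_id": [1, 2, 1, 1, 3, 2],
                   "anime_id": [10, 11, 12, 13, 14, 15],
                   "user_rating": [5, 6, 7, 8, 9, 10]})
X_tr, X_te, Y_tr, Y_te = train_test_split_fast(df)
print(X_tr[:, 1].tolist(), X_te[:, 1].tolist())  # [10, 11, 12, 14] [13, 15]
```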

Few Sanity Checks

Below, we again perform a simple check: we count the user ids present in train but not in test, and vice versa. This verifies that the above code split the data so that every user appears in both sets.

train_users = set(np.unique(X_train[:,0]))
test_users = set(np.unique(X_test[:,0]))

train_not_test = train_users.difference(test_users)
test_not_train = test_users.difference(train_users)

print("Users in Train but not Test : {}".format(len(train_not_test)))
print("Users in Test  but not Train : {}".format(len(test_not_train)))
Users in Train but not Test : 0
Users in Test  but not Train : 0

Model 1: Embeddings

In this section, we'll design our neural network using embeddings. Embeddings are weights that map ids (user ids/item ids) to vectors. As discussed earlier, in collaborative filtering we keep a vector of floats for each unique user and item. For example, if there are 10 unique users and we maintain a vector of 50 floats per user, we'll have 10 such vectors of 50 floats, one per user. Vectors are maintained for items (movies, anime, books, etc.) in the same way. These vectors are called embeddings. A user's embedding generally represents that user's taste (whether they like action, adventure, comedy, or horror movies), and an item's embedding represents the characteristics of that item (whether the movie is horror, action, adventure, etc.). The user embedding layer stores one embedding per user and maps a user id to its embedding, and the item embedding layer does the same for item ids. The size of these vectors is decided up front and is the same for all user ids and item ids.
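Conceptually, an embedding layer is a lookup table: looking up an id returns the corresponding row of a (num_ids, n_factors) weight matrix, which is equivalent to multiplying a one-hot vector by that matrix. A minimal NumPy sketch with a made-up 5-user, 3-factor table:

```python
import numpy as np

n_users, n_factors = 5, 3
rng = np.random.default_rng(0)
user_embeddings = rng.normal(size=(n_users, n_factors))   # one row per user

user_ids = np.array([4, 0, 4])
looked_up = user_embeddings[user_ids]                      # embedding lookup = row indexing

one_hot = np.eye(n_users)[user_ids]                        # (3, n_users) one-hot rows
assert np.allclose(looked_up, one_hot @ user_embeddings)   # same result as a matrix product
print(looked_up.shape)  # (3, 3)
```

Row indexing is what makes embeddings cheap: the one-hot product gives the same answer but does far more arithmetic.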

Create Model

In this section, we have defined a simple network for the collaborative filtering task. We have set the vector size for each user and item to 50, which means that each user and item will be represented by 50 floats. This number is generally referred to as the number of factors and is one of the network hyperparameters that needs to be tuned. We tried values such as 40, 50, 70, 100, and 120, and settled on 50.

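We have created the model by extending the flax.linen.Module class. The setup() method initializes the layers, and the __call__() method performs the actual forward pass through the network. We have declared two embedding layers (one for users and one for items) using the Embed layer of flax.linen. The user embedding layer holds a float array of shape (n_users, n_factors) that maps a user id to its vector of floats, and the item embedding layer holds a float array of shape (n_items, n_factors) that maps an item id to its vector. Please make a NOTE that Embed looks up rows by index, so it expects ids in the ranges [0, n_users) and [0, n_items). The raw user_id and anime_id values in this dataset are not contiguous (for example, user 71792 and anime 31229 in the shuffled data above exceed 66351 and 9926), and JAX does not raise an error for out-of-range indices: depending on the JAX version, they are clamped to the last row or filled with NaN. Remapping ids to contiguous indices before training (the item_to_idx dictionary created earlier is one such mapping for anime ids) gives every user and item its own embedding.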
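A minimal sketch of such a remapping on a made-up dataframe. The user_idx/item_idx column names are hypothetical; in the tutorial's pipeline, the same dictionaries would also have to be applied to the ids passed to the model when making predictions and recommendations. pd.factorize() is an equivalent built-in way to get the codes.

```python
import pandas as pd

# Hypothetical raw ids: not contiguous, and larger than the number of unique ids
df = pd.DataFrame({"user_id": [71792, 5, 71792, 17026],
                   "anime_id": [31229, 8074, 1, 31229]})

# Map each raw id to a contiguous index in [0, n_unique)
user_to_idx = {uid: i for i, uid in enumerate(df["user_id"].unique())}
item_to_idx = {iid: i for i, iid in enumerate(df["anime_id"].unique())}

df["user_idx"] = df["user_id"].map(user_to_idx)
df["item_idx"] = df["anime_id"].map(item_to_idx)

n_users, n_items = len(user_to_idx), len(item_to_idx)
assert df["user_idx"].max() < n_users and df["item_idx"].max() < n_items
print(df[["user_idx", "item_idx"]].to_numpy().tolist())  # [[0, 0], [1, 1], [0, 2], [2, 0]]
```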

The forward pass through the network (the __call__() method) multiplies user embeddings by item embeddings element-wise and sums the result, i.e., it takes the dot product of the user and item vectors to produce a rating. The network takes a batch of (user_id, item_id) pairs as input, looks up the vector for each user id and item id, multiplies them, and sums them to predict the rating given by each user to each item. During training, we update the user and item weights to improve these rating predictions.

If network creation and training seem a little hard to grasp, we suggest going through our tutorial on Flax, which explains how to create neural networks, as it'll help with this one as well.

Why We Are Treating Rating Prediction Task As A Regression Task?

We have treated rating prediction as a REGRESSION task even though ratings are integers in the range [1, 10]. The reason is that we want the model to predict a rating that is close to the actual rating, and we can then round the float prediction to the nearest integer. A classification model treats all classes as independent of each other and does not enforce any ORDERING among the target classes, so when it is wrong, it can predict any of the possible classes. A regression model, in contrast, is penalized by how far its prediction is from the actual rating, so even its wrong predictions tend to be close to the actual rating. In our case, the ratings are ordered (a rating of 3 is better than 2, 4 is better than 3, and so on), and treating the task as REGRESSION lets the model exploit this ordering.
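A small worked example of the ordering argument, assuming an actual rating of 8 and two hypothetical wrong guesses: squared error (a regression loss) punishes a prediction of 2 far more than a prediction of 7, while a plain 0/1 classification error counts both mistakes equally.

```python
actual = 8
close_guess, far_guess = 7, 2

# Squared error grows with the distance from the actual rating
squared_errors = [(actual - close_guess) ** 2, (actual - far_guess) ** 2]
print(squared_errors)  # [1, 36]

# A 0/1 classification error treats both wrong guesses the same
zero_one_errors = [int(close_guess != actual), int(far_guess != actual)]
print(zero_one_errors)  # [1, 1]
```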

Please make a NOTE that all of our models approach rating prediction as a REGRESSION task, even though ratings take a small set of discrete values and could look like a CLASSIFICATION task. The reason is that a plain classification setup does not capture the order of the ratings.

from flax import linen

n_factors = 50

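class SimpleRecSystem(linen.Module):
    ## Sizes are taken from the globals computed earlier
    n_users = n_users
    n_items = n_items
    n_factors = n_factors

    def setup(self):
        ## One n_factors-dim vector per user and per item
        self.user_embeddings = linen.Embed(self.n_users, self.n_factors, name="User Embeddings")
        self.item_embeddings = linen.Embed(self.n_items, self.n_factors, name="Item Embeddings")

    def __call__(self, X_batch):
        ## X_batch has shape (batch_size, 2): column 0 = user id, column 1 = item id
        users = self.user_embeddings(X_batch[:,0]) ## (batch_size, n_factors)
        items = self.item_embeddings(X_batch[:,1]) ## (batch_size, n_factors)
        return (users * items).sum(axis=1) ## Row-wise dot product -> (batch_size,)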

Below, we have initialized the network and printed the shapes of the embeddings. We have also performed a forward pass through the network for verification purposes.

from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

rec_system = SimpleRecSystem()

params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights = layer_params[1]["embedding"]
    print("\tLayer Weights : {}".format(weights.shape))
Layer Name : User Embeddings
	Layer Weights : (66351, 50)
Layer Name : Item Embeddings
	Layer Weights : (9926, 50)
preds = rec_system.apply(params, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

preds.shape
(100,)

Define Loss Function

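In this section, we have defined the loss functions. Although our main loss function is named MSELoss, it takes the square root of the mean squared error, so it actually computes the root mean squared error (RMSE). RMSE penalizes the distance between the actual and predicted ratings and is expressed in rating units, which is why the "MSELoss" values printed during training are really RMSE values. We have also defined a mean absolute error (MAE) loss, which can also be used as a loss function for regression tasks.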

def MSELoss(params, input_data, actual):
    ## NOTE: despite its name, this returns the ROOT mean squared error (RMSE)
    preds = rec_system.apply(params, input_data) ## Uses the global rec_system
    return jnp.sqrt(jnp.power(actual.squeeze() - preds.squeeze(), 2).mean())

def MAELoss(params, input_data, actual):
    preds = rec_system.apply(params, input_data) ## Uses the global rec_system
    return jnp.abs(actual.squeeze() - preds.squeeze()).mean()
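A quick NumPy check, on three hypothetical ratings, of how the three error measures relate: MSELoss above returns the square root of the MSE, and MAELoss returns the mean absolute error.

```python
import numpy as np

actual = np.array([10.0, 6.0, 8.0])
preds = np.array([8.0, 7.0, 8.0])

errors = actual - preds           # [2, -1, 0]
mse = np.mean(errors ** 2)        # (4 + 1 + 0) / 3
rmse = np.sqrt(mse)               # what MSELoss above actually returns
mae = np.mean(np.abs(errors))     # what MAELoss above returns

print(round(mse, 3), round(rmse, 3), round(mae, 3))  # 1.667 1.291 1.0
```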

Train Model

In this section, we train our embedding network. We have created a simple training function that loops through the train data in batches. For each batch of (user id, item id) pairs, it predicts ratings, calculates the loss and gradients, and updates the weights using the gradients. After each epoch, it prints the average train loss and the loss on the test set, which we use as a validation set. At the end, the function returns the updated network parameters. Please make a NOTE that the function uses the global optimizer object rather than receiving it as an argument.

from jax import value_and_grad
from tqdm import tqdm

def TrainModel(X, Y, X_val, Y_val, epochs, params, optimizer_state, batch_size=256):
    ## NOTE: uses the global `optimizer` (and, through MSELoss, the global `rec_system`)
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0] + batch_size - 1)//batch_size) ### Batch indices (ceil division, so the last batch is never empty)

        losses = [] ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(MSELoss)(params, X_batch,Y_batch) ## Forward Pass and Loss Calculation

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            params = optax.apply_updates(params, updates)

            losses.append(loss) ## Record Loss

        print("Train MSELoss : {:.3f}".format(jnp.array(losses).mean()))
        val_loss = MSELoss(params, X_val, Y_val)
        print("Valid MSELoss : {:.3f}".format(val_loss))
        gc.collect()
    return params
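TrainModel above calls value_and_grad on every batch without jax.jit, so each step runs operation by operation (about 9 batches per second in the training log). A common speed-up is to compile the whole update step with jax.jit. The sketch below is self-contained and deliberately simplified: it uses a plain dictionary of embedding matrices instead of the Flax module, plain MSE, and plain gradient descent instead of optax.adam. All names and the learning rate are hypothetical. A jitted function is recompiled for each new input shape, so a smaller last batch triggers one extra compilation.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # hypothetical value for plain gradient descent

def predict(params, X_batch):
    # Row-wise dot product of the looked-up user and item vectors
    users = params["user"][X_batch[:, 0]]
    items = params["item"][X_batch[:, 1]]
    return (users * items).sum(axis=1)

def mse_loss(params, X_batch, Y_batch):
    return jnp.mean((Y_batch - predict(params, X_batch)) ** 2)

@jax.jit  # compiled once per input shape, then reused for every batch of that shape
def train_step(params, X_batch, Y_batch):
    loss, grads = jax.value_and_grad(mse_loss)(params, X_batch, Y_batch)
    # Plain gradient descent applied to every array in the params dict
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss

# Tiny usage example with random 3-factor embeddings (4 users, 5 items)
k_user, k_item = jax.random.split(jax.random.PRNGKey(0))
params = {"user": 0.1 * jax.random.normal(k_user, (4, 3)),
          "item": 0.1 * jax.random.normal(k_item, (5, 3))}
X = jnp.array([[0, 1], [1, 2], [3, 4], [2, 0]])
Y = jnp.array([8.0, 6.0, 9.0, 7.0])

first_loss = mse_loss(params, X, Y)
for _ in range(200):
    params, loss = train_step(params, X, Y)
```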

Below, we train our network using the function defined above. We have set the number of epochs to 10, the batch size to 10,000, and the learning rate to 0.001. We then initialize the network and its weights, along with an Adam optimizer to update the network parameters. Finally, we call our function to train the network.

From the loss values printed after each epoch, we can notice that the loss drops quickly over the first eight epochs and then levels off. The final train and validation losses are around 1.44 and 1.46. Since the loss is an RMSE, this means our rating predictions are typically off by about 1.4 to 1.5 points from the actual rating on the 1 to 10 scale.

Later on, we'll try more models to bring this error further down (below 1.4).

seed = jax.random.PRNGKey(0)
epochs=10
batch_size = 10000
learning_rate=0.001

rec_system = SimpleRecSystem()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(params)

final_params = TrainModel(X_train, Y_train, X_test, Y_test, epochs, params, optimizer_state, batch_size)
100%|██████████| 567/567 [01:05<00:00,  8.67it/s]
Train MSELoss : 7.910
Valid MSELoss : 7.828
100%|██████████| 567/567 [00:58<00:00,  9.70it/s]
Train MSELoss : 7.530
Valid MSELoss : 7.239
100%|██████████| 567/567 [00:58<00:00,  9.74it/s]
Train MSELoss : 6.946
Valid MSELoss : 6.583
100%|██████████| 567/567 [01:00<00:00,  9.36it/s]
Train MSELoss : 6.146
Valid MSELoss : 5.372
100%|██████████| 567/567 [01:00<00:00,  9.32it/s]
Train MSELoss : 4.741
Valid MSELoss : 3.420
100%|██████████| 567/567 [01:00<00:00,  9.38it/s]
Train MSELoss : 2.861
Valid MSELoss : 1.903
100%|██████████| 567/567 [01:02<00:00,  9.04it/s]
Train MSELoss : 1.772
Valid MSELoss : 1.517
100%|██████████| 567/567 [01:00<00:00,  9.37it/s]
Train MSELoss : 1.452
Valid MSELoss : 1.456
100%|██████████| 567/567 [01:01<00:00,  9.16it/s]
Train MSELoss : 1.418
Valid MSELoss : 1.478
100%|██████████| 567/567 [01:02<00:00,  9.07it/s]
Train MSELoss : 1.440
Valid MSELoss : 1.459
gc.collect()
21

Predict Ratings

In this section, we use our trained model to predict ratings for an existing user. We have designed a function for this. The function randomly selects one user id and predicts ratings for the anime that this user has rated. It then places the actual and predicted ratings in a single dataframe along with other useful details. The function also clips the predictions to the range [1, 10.5] and rounds them to integer ratings (values with a fractional part above 0.5 are rounded up, others down). Please make a NOTE that most of these ratings were part of the training data, so this is a sanity check rather than a held-out evaluation. We'll reuse this function in upcoming sections when we try different models.

import math
def make_predictions(rng, rec_system, final_params):
    user_id = rng.choice(more_review_users, size=1)[0]

    cols_of_interest = ["user_id","anime_id","user_rating","name","genre","type","episodes","average_anime_rating","members"]
    user_data = ratings_df[ratings_df["user_id"] == user_id][cols_of_interest].copy() ## copy() avoids SettingWithCopyWarning below
    preds = rec_system.apply(final_params, user_data[["user_id","anime_id"]].values)
    preds = jnp.clip(preds,1, 10.5) ## Keep predictions in the valid rating range
    preds = np.array([math.ceil(pred) if (pred-int(pred)) > 0.5 else math.floor(pred) for pred in preds]) ## Round to integer ratings
    user_data["pred_rating"] = preds
    cols_of_interest.insert(3, "pred_rating")
    return user_data[cols_of_interest], user_id

rng = np.random.RandomState(1)
user_data, user_id = make_predictions(rng, rec_system, final_params)
user_data.sort_values(by=["user_rating"], ascending=False).head(15)
user_id anime_id user_rating pred_rating name genre type episodes average_anime_rating members
2401263 19880 2904 10 9 Code Geass: Hangyaku no Lelouch R2 Action, Drama, Mecha, Military, Sci-Fi, Super ... TV 25 8.98 572888
497418 19880 19815 10 8 No Game No Life Adventure, Comedy, Ecchi, Fantasy, Game, Super... TV 12 8.47 602291
39491 19880 11757 10 8 Sword Art Online Action, Adventure, Fantasy, Game, Romance TV 25 7.83 893100
2331510 19880 1575 10 9 Code Geass: Hangyaku no Lelouch Action, Mecha, Military, School, Sci-Fi, Super... TV 25 8.83 715151
626895 19880 28121 10 8 Dungeon ni Deai wo Motomeru no wa Machigatteir... Action, Adventure, Comedy, Fantasy, Romance TV 13 7.88 336349
2095991 19880 23233 10 8 Shinmai Maou no Testament Action, Demons, Ecchi, Fantasy, Harem, Romance TV 12 7.11 172321
4270539 19880 31478 10 8 Bungou Stray Dogs Action, Comedy, Mystery, Seinen, Supernatural TV 12 7.76 187805
4263788 19880 27991 10 8 K: Return of Kings Action, Super Power, Supernatural TV 13 7.82 114904
2006895 19880 20785 10 8 Mahouka Koukou no Rettousei Magic, Romance, School, Sci-Fi, Supernatural TV 26 7.76 285317
5603987 19880 856 10 8 Utawarerumono Action, Drama, Fantasy, Sci-Fi TV 26 7.78 91034
3639572 19880 23283 10 8 Zankyou no Terror Psychological, Thriller TV 11 8.26 342893
1465635 19880 9253 10 9 Steins;Gate Sci-Fi, Thriller TV 24 9.17 673572
3682791 19880 28623 10 8 Koutetsujou no Kabaneri Action, Drama, Fantasy, Horror TV 12 7.39 253027
2662790 19880 22663 10 8 Seiken Tsukai no World Break Action, Fantasy, Harem, Romance, School, Super... TV 12 7.16 97401
302681 19880 6702 10 8 Fairy Tail Action, Adventure, Comedy, Fantasy, Magic, Sho... TV 175 8.22 584590
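make_predictions above rounds predictions with a Python loop over every value. For positive (clipped) predictions, the same "round up only if the fractional part is above 0.5" rule can be sketched as a single vectorized np.ceil(p - 0.5). The example values are hypothetical. np.round would not match exactly, because it rounds halves to the nearest even number (7.5 becomes 8).

```python
import math
import numpy as np

preds = np.array([0.2, 7.4, 7.5, 7.51, 11.3])
clipped = np.clip(preds, 1, 10.5)

# Loop version used in make_predictions
loop_rounded = np.array([math.ceil(p) if (p - int(p)) > 0.5 else math.floor(p) for p in clipped])

# Vectorized equivalent for positive values: ceil(p - 0.5) rounds halves down
vec_rounded = np.ceil(clipped - 0.5)

print(loop_rounded.tolist(), vec_rounded.tolist())  # [1, 7, 7, 8, 10] [1.0, 7.0, 7.0, 8.0, 10.0]
```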

Recommend New Items

In this section, we recommend new items to existing users. Since users have not rated every anime, we can use our trained model, which has learned something about each user's taste, to recommend anime they have not rated yet. We have created a small function for this.

The function takes as input a user id, that user's data, the recommender model, and the model parameters. It retrieves the item ids not rated by the user, uses the trained model to predict ratings for those items, rounds them, and sorts them from high to low. We can then recommend the items with the highest predicted ratings, i.e., the ratings the model expects the user would give if they rated those items.

def recommend_new_items(user_id, user_data, rec_system, final_params):
    total_item_ids = ratings_df["anime_id"].unique().tolist()
    item_ids_user_watched = user_data["anime_id"].values.tolist()

    item_ids_not_watched_by_user = list(set(total_item_ids).difference(item_ids_user_watched))
    data = np.stack(([user_id]*len(item_ids_not_watched_by_user),item_ids_not_watched_by_user), axis=1)
    preds = rec_system.apply(final_params, data)
    preds = np.array([math.ceil(pred) if (pred-int(pred)) > 0.5 else math.floor(pred) for pred in preds])
    preds = preds.reshape(-1,1)
    recommendation_df = pd.DataFrame(np.hstack((data,preds)),columns=["user_id","anime_id","pred_rating"])
    recommendation_df = recommendation_df.sort_values(by=["pred_rating"],ascending=False)
    cols_of_interest = ["anime_id","name","genre","type","episodes","average_anime_rating","members"]
    recommendation_df = recommendation_df.merge(ratings_df[cols_of_interest].drop_duplicates(), on="anime_id")

    return recommendation_df

recommendation_df = recommend_new_items(user_id, user_data, rec_system, final_params)
recommendation_df.head(10)
user_id anime_id pred_rating name genre type episodes average_anime_rating members
0 19880 7785 9 Yojouhan Shinwa Taikei Mystery, Psychological, Romance TV 11 8.65 122531
1 19880 4282 9 Kara no Kyoukai 5: Mujun Rasen Action, Drama, Mystery, Romance, Supernatural,... Movie 1 8.68 111074
2 19880 572 9 Kaze no Tani no Nausicaä Adventure, Fantasy Movie 1 8.47 143273
3 19880 44 9 Rurouni Kenshin: Meiji Kenkaku Romantan - Tsui... Action, Drama, Historical, Martial Arts, Roman... OVA 4 8.83 129307
4 19880 263 9 Hajime no Ippo Comedy, Drama, Shounen, Sports TV 75 8.83 157670
5 19880 777 9 Hellsing Ultimate Action, Horror, Military, Seinen, Supernatural... OVA 10 8.59 297454
6 19880 57 9 Beck Comedy, Drama, Music, Shounen, Slice of Life TV 26 8.40 148328
7 19880 329 9 Planetes Drama, Romance, Sci-Fi, Seinen, Space TV 26 8.38 105044
8 19880 3901 9 Baccano! Specials Action, Comedy, Historical, Mystery, Seinen, S... Special 3 8.29 100412
9 19880 245 9 Great Teacher Onizuka Comedy, Drama, School, Shounen, Slice of Life TV 43 8.77 268487
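recommend_new_items above builds one (user_id, item_id) row per unrated item and calls the model on all of them. When the item embeddings can be pulled out as a matrix, the same ranking can be sketched as one matrix-vector product. The helper name and toy arrays below are hypothetical, biases are ignored, and item indices are assumed to be contiguous (0 to n_items - 1).

```python
import numpy as np

def top_n_unrated(user_vec, item_embeddings, rated_item_idx, n=3):
    # Hypothetical helper: score every item at once with a matrix-vector product
    scores = item_embeddings @ user_vec          # (n_items,) predicted ratings
    scores[rated_item_idx] = -np.inf             # never recommend already-rated items
    top = np.argsort(-scores)[:n]                # indices of the n highest scores
    return top, scores[top]

item_embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.5], [0.5, 0.5]])
user_vec = np.array([3.0, 2.0])
top, top_scores = top_n_unrated(user_vec, item_embeddings, rated_item_idx=[2], n=3)
print(top.tolist(), top_scores.tolist())  # [3, 0, 4] [7.0, 3.0, 2.5]
```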

How to Recommend Items to New Users?

An obvious question when using collaborative filtering for recommendation is how to handle new users. The user embedding table has one row per user known at training time, so a new user has no embedding yet. To handle new users, we need to add embeddings for them (initially random) and retrain the model so that those embeddings are learned from their ratings. We can retrain the model daily if many new users are arriving, or weekly, bi-weekly, or even monthly if new users are added less frequently. Until a new user has rated some items, we can recommend items that have high overall ratings from the majority of users. Once the user has watched/rated other items, we can update their embeddings by training the network.
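A minimal sketch of the "recommend highly rated items first" fallback for a new user, on a made-up ratings dataframe. The popular_items helper and its min_ratings threshold are hypothetical; the threshold keeps items with a single 10 from dominating the list.

```python
import pandas as pd

def popular_items(ratings_df, min_ratings=2, n=2):
    # Hypothetical cold-start fallback: highest mean rating among items with enough ratings
    stats = ratings_df.groupby("anime_id")["user_rating"].agg(["mean", "count"])
    stats = stats[stats["count"] >= min_ratings]
    return stats.sort_values("mean", ascending=False).head(n).index.tolist()

ratings_df = pd.DataFrame({"user_id":     [1, 2, 3, 1, 2, 3, 1],
                           "anime_id":    [10, 10, 10, 11, 11, 12, 13],
                           "user_rating": [9, 8, 10, 6, 7, 10, 10]})
print(popular_items(ratings_df))  # [10, 11]
```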

How to Recommend New Items to Users?

Another obvious question is how to handle items that are new to the platform. Like new users, new items start with random embeddings. One solution is to keep recommending new items to users at random until we have a few ratings for them, and then update the embeddings of the new items with those ratings by training our network. We can also highlight newly added items on the home page at first, as platforms like Netflix and Amazon do, and send notifications to users who have watched/used an item but have not rated it yet.
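Adding embeddings for new items (or users) amounts to appending rows to the embedding matrix while keeping the trained rows. A minimal NumPy sketch: the add_rows helper is hypothetical, and initializing new rows with small normal values is an assumption, not Flax's default initializer. In the tutorial's Flax model, one would also need a module with a larger n_items and to write the enlarged array back into params["params"]["Item Embeddings"]["embedding"] (for example after flax.core.unfreeze) before retraining; that wiring is not shown here.

```python
import numpy as np

def add_rows(embedding, n_new, rng, scale=0.01):
    # Hypothetical helper: append randomly initialized rows for new ids.
    # Existing (already trained) rows are kept unchanged.
    new_rows = rng.normal(scale=scale, size=(n_new, embedding.shape[1]))
    return np.concatenate([embedding, new_rows], axis=0)

rng = np.random.default_rng(0)
item_embeddings = rng.normal(size=(9926, 50))   # stands in for a trained item table
grown = add_rows(item_embeddings, n_new=3, rng=rng)
print(grown.shape)  # (9929, 50)
```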

Model 2: Embeddings + Bias

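In this section, we add bias terms to the model: a single learned number for each user and each item, which is added to the dot product of their embeddings. A user bias can capture that some users rate everything higher or lower than others, and an item bias can capture that some items are generally rated higher or lower than others. This can also help the embeddings train better, since they no longer have to model these per-user and per-item offsets.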

Create Model

Much of our model's code in this section is the same as in the earlier section, with a few additions. We have declared user and item biases in the setup() method using the Embed layer with one feature per id. In the __call__() method, we first compute the dot product of the user and item embeddings and then add the user and item biases to it before returning the result.
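In other words, predicted rating = user vector · item vector + user bias + item bias. A worked NumPy example for a single (user, item) pair, with made-up 2-factor vectors and biases:

```python
import numpy as np

user_vec = np.array([1.0, 2.0])
item_vec = np.array([3.0, 0.5])
user_bias, item_bias = 0.5, 1.5   # hypothetical learned biases

# Same computation as SimpleRecSystemWithBias, for one (user, item) pair
pred = (user_vec * item_vec).sum() + user_bias + item_bias
print(pred)  # 6.0  (dot product 4.0 + 0.5 + 1.5)
```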

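import jax
from jax import numpy as jnp
from flax import linen

n_factors = 50

class SimpleRecSystemWithBias(linen.Module):
    n_users = n_users      # total number of users (defined earlier in the tutorial)
    n_items = n_items      # total number of items (defined earlier in the tutorial)
    n_factors = n_factors  # size of each embedding vector

    def setup(self):
        # One n_factors-dimensional vector per user and per item
        self.user_embeddings = linen.Embed(self.n_users, self.n_factors, name="User Embeddings")
        self.item_embeddings = linen.Embed(self.n_items, self.n_factors, name="Item Embeddings")
        # One learnable scalar (feature size 1) per user and per item
        self.user_bias = linen.Embed(self.n_users, 1, name="User Bias")
        self.item_bias = linen.Embed(self.n_items, 1, name="Item Bias")

    def __call__(self, X_batch):
        # Column 0 holds user ids, column 1 holds item ids
        users = self.user_embeddings(X_batch[:,0])
        items = self.item_embeddings(X_batch[:,1])
        result = (users * items).sum(axis=1)   # dot product per row -> (batch,)
        # Bias lookups have shape (batch, 1); squeeze them to (batch,) before adding
        result += self.user_bias(X_batch[:,0]).squeeze() + self.item_bias(X_batch[:,1]).squeeze()
        return result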

seed = jax.random.PRNGKey(0)

rec_system = SimpleRecSystemWithBias()

params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights = layer_params[1]["embedding"]
    print("\tLayer Weights : {}".format(weights.shape))
Layer Name : User Embeddings
	Layer Weights : (66351, 50)
Layer Name : Item Embeddings
	Layer Weights : (9926, 50)
Layer Name : User Bias
	Layer Weights : (66351, 1)
Layer Name : Item Bias
	Layer Weights : (9926, 1)
preds = rec_system.apply(params, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

preds.shape
(100,)
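Because linen.Embed essentially acts as a lookup table, indexing the weight matrices directly should reproduce what `SimpleRecSystemWithBias.__call__()` computes. The plain-JAX function below is a minimal sketch of that forward pass under this assumption; the function name and the tiny hand-made tables are ours, chosen so the result can be checked by hand.

```python
import jax.numpy as jnp

def predict_with_bias(user_emb, item_emb, user_bias, item_bias, X_batch):
    """Hypothetical plain-JAX version of SimpleRecSystemWithBias.__call__.

    user_emb: (n_users, n_factors)   item_emb: (n_items, n_factors)
    user_bias: (n_users, 1)          item_bias: (n_items, 1)
    X_batch: (batch, 2) integer array of [user_id, item_id] pairs
    """
    u, i = X_batch[:, 0], X_batch[:, 1]
    dot = (user_emb[u] * item_emb[i]).sum(axis=1)   # (batch,)
    # Taking column 0 turns each (batch, 1) bias lookup into (batch,)
    return dot + user_bias[u, 0] + item_bias[i, 0]

user_emb = jnp.ones((3, 2))
item_emb = jnp.arange(8.0).reshape(4, 2)            # rows: [0,1], [2,3], [4,5], [6,7]
user_bias = jnp.array([[0.0], [1.0], [2.0]])
item_bias = jnp.array([[10.0], [20.0], [30.0], [40.0]])
X_batch = jnp.array([[0, 1], [2, 3]])

preds = predict_with_bias(user_emb, item_emb, user_bias, item_bias, X_batch)
print(preds)  # [25. 55.]
```

For the first pair, user 0 and item 1 give a dot product of 1*2 + 1*3 = 5, plus a user bias of 0 and an item bias of 20, for a total of 25.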

Train Model

In this section, we train our new recommendation system with biases for 10 epochs with a batch size of 10,000 and a learning rate of 0.001. Judging by the loss values printed at the end of each epoch (the validation MSE ends at 1.412), this model seems to have done slightly better than our earlier model.

seed = jax.random.PRNGKey(0)
epochs=10
batch_size = 10000
learning_rate=0.001

rec_system = SimpleRecSystemWithBias()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(params)

final_params = TrainModel(X_train,Y_train, X_test, Y_test, epochs,params,optimizer_state,batch_size)
100%|██████████| 567/567 [01:32<00:00,  6.10it/s]
Train MSELoss : 7.743
Valid MSELoss : 7.396
100%|██████████| 567/567 [01:35<00:00,  5.95it/s]
Train MSELoss : 6.890
Valid MSELoss : 6.387
100%|██████████| 567/567 [01:36<00:00,  5.87it/s]
Train MSELoss : 5.963
Valid MSELoss : 5.396
100%|██████████| 567/567 [01:34<00:00,  6.01it/s]
Train MSELoss : 4.853
Valid MSELoss : 3.952
100%|██████████| 567/567 [01:33<00:00,  6.05it/s]
Train MSELoss : 3.360
Valid MSELoss : 2.353
100%|██████████| 567/567 [01:35<00:00,  5.97it/s]
Train MSELoss : 2.110
Valid MSELoss : 1.652
100%|██████████| 567/567 [01:36<00:00,  5.86it/s]
Train MSELoss : 1.568
Valid MSELoss : 1.460
100%|██████████| 567/567 [01:39<00:00,  5.73it/s]
Train MSELoss : 1.407
Valid MSELoss : 1.425
100%|██████████| 567/567 [01:34<00:00,  6.00it/s]
Train MSELoss : 1.385
Valid MSELoss : 1.428
100%|██████████| 567/567 [01:34<00:00,  6.02it/s]
Train MSELoss : 1.395
Valid MSELoss : 1.412
gc.collect()
21
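TrainModel is defined earlier in the tutorial and isn't repeated here. As a rough sketch of what one training step for the bias model involves, the code below computes the MSE loss and its gradient with `jax.value_and_grad`. To stay dependency-free, it swaps in plain gradient descent with a hand-picked learning rate of 0.1 instead of `optax.adam`. The parameter dictionary layout and the toy data are also our assumptions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumption for this toy example; the tutorial uses Adam with 0.001

def mse_loss(params, X_batch, y_batch):
    """MSE between bias-model predictions and true ratings."""
    u, i = X_batch[:, 0], X_batch[:, 1]
    preds = (params["user"][u] * params["item"][i]).sum(axis=1)
    preds = preds + params["user_bias"][u] + params["item_bias"][i]
    return jnp.mean((preds - y_batch) ** 2)

@jax.jit
def sgd_step(params, X_batch, y_batch):
    """One plain gradient-descent step (a stand-in for the Adam update)."""
    loss, grads = jax.value_and_grad(mse_loss)(params, X_batch, y_batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss

# Toy data: 2 users, 2 items, 3 ratings
keys = jax.random.split(jax.random.PRNGKey(0), 2)
params = {
    "user": 0.1 * jax.random.normal(keys[0], (2, 4)),
    "item": 0.1 * jax.random.normal(keys[1], (2, 4)),
    "user_bias": jnp.zeros(2),
    "item_bias": jnp.zeros(2),
}
X = jnp.array([[0, 0], [1, 1], [0, 1]])
y = jnp.array([1.0, 2.0, 3.0])

losses = []
for _ in range(50):
    params, loss = sgd_step(params, X, y)   # loss is computed before the update
    losses.append(float(loss))
print(losses[0], losses[-1])  # the loss should go down
```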

Predict Ratings

In this section, we make predictions for items that the user has already seen and rated. The results look almost the same as our previous model's: for the titles this user rated 10, the model predicts ratings between 7 and 9.

rng = np.random.RandomState(1)
user_data, user_id = make_predictions(rng, rec_system, final_params)
user_data.sort_values(by=["user_rating"], ascending=False).head(15)
user_id anime_id user_rating pred_rating name genre type episodes average_anime_rating members
2401263 19880 2904 10 9 Code Geass: Hangyaku no Lelouch R2 Action, Drama, Mecha, Military, Sci-Fi, Super ... TV 25 8.98 572888
497418 19880 19815 10 7 No Game No Life Adventure, Comedy, Ecchi, Fantasy, Game, Super... TV 12 8.47 602291
39491 19880 11757 10 7 Sword Art Online Action, Adventure, Fantasy, Game, Romance TV 25 7.83 893100
2331510 19880 1575 10 8 Code Geass: Hangyaku no Lelouch Action, Mecha, Military, School, Sci-Fi, Super... TV 25 8.83 715151
626895 19880 28121 10 7 Dungeon ni Deai wo Motomeru no wa Machigatteir... Action, Adventure, Comedy, Fantasy, Romance TV 13 7.88 336349
2095991 19880 23233 10 7 Shinmai Maou no Testament Action, Demons, Ecchi, Fantasy, Harem, Romance TV 12 7.11 172321
4270539 19880 31478 10 7 Bungou Stray Dogs Action, Comedy, Mystery, Seinen, Supernatural TV 12 7.76 187805
4263788 19880 27991 10 7 K: Return of Kings Action, Super Power, Supernatural TV 13 7.82 114904
2006895 19880 20785 10 7 Mahouka Koukou no Rettousei Magic, Romance, School, Sci-Fi, Supernatural TV 26 7.76 285317
5603987 19880 856 10 7 Utawarerumono Action, Drama, Fantasy, Sci-Fi TV 26 7.78 91034
3639572 19880 23283 10 7 Zankyou no Terror Psychological, Thriller TV 11 8.26 342893
1465635 19880 9253 10 9 Steins;Gate Sci-Fi, Thriller TV 24 9.17 673572
3682791 19880 28623 10 7 Koutetsujou no Kabaneri Action, Drama, Fantasy, Horror TV 12 7.39 253027
2662790 19880 22663 10 7 Seiken Tsukai no World Break Action, Fantasy, Harem, Romance, School, Super... TV 12 7.16 97401
302681 19880 6702 10 8 Fairy Tail Action, Adventure, Comedy, Fantasy, Magic, Sho... TV 175 8.22 584590

Recommend New Items

In this section, we recommend items that the user has not rated yet.

recommendation_df = recommend_new_items(user_id, user_data, rec_system, final_params)
recommendation_df.head(10)
user_id anime_id pred_rating name genre type episodes average_anime_rating members
0 19880 4565 9 Tengen Toppa Gurren Lagann Movie: Lagann-hen Action, Mecha, Sci-Fi, Space, Super Power Movie 1 8.64 82253
1 19880 918 9 Gintama Action, Comedy, Historical, Parody, Samurai, S... TV 201 9.04 336376
2 19880 329 9 Planetes Drama, Romance, Sci-Fi, Seinen, Space TV 26 8.38 105044
3 19880 820 9 Ginga Eiyuu Densetsu Drama, Military, Sci-Fi, Space OVA 110 9.11 80679
4 19880 5114 9 Fullmetal Alchemist: Brotherhood Action, Adventure, Drama, Fantasy, Magic, Mili... TV 64 9.26 793665
5 19880 7311 9 Suzumiya Haruhi no Shoushitsu Comedy, Mystery, Romance, School, Sci-Fi, Supe... Movie 1 8.81 240297
6 19880 4282 8 Kara no Kyoukai 5: Mujun Rasen Action, Drama, Mystery, Romance, Supernatural,... Movie 1 8.68 111074
7 19880 4280 8 Kara no Kyoukai 4: Garan no Dou Action, Mystery, Supernatural, Thriller Movie 1 8.05 103600
8 19880 1559 8 Shijou Saikyou no Deshi Kenichi Action, Comedy, Martial Arts, School, Shounen TV 50 8.25 129112
9 19880 6746 8 Durarara!! Action, Mystery, Supernatural TV 24 8.38 556431
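recommend_new_items is defined earlier in the tutorial. A minimal NumPy sketch of the general idea follows: score every item for one user, exclude the items the user has already rated, and keep the k highest scores. The helper `top_k_unrated` and its arguments are hypothetical. For Model 2, `user_bias` and `item_bias` would be that model's learned bias values.

```python
import numpy as np

def top_k_unrated(user_vec, item_matrix, rated_item_ids, k=10, user_bias=0.0, item_bias=None):
    """Hypothetical helper: indices and scores of the k best items the user hasn't rated."""
    scores = (item_matrix @ user_vec + user_bias).astype(float)   # one score per item
    if item_bias is not None:
        scores = scores + item_bias
    scores[list(rated_item_ids)] = -np.inf           # never re-recommend rated items
    k = min(k, len(scores) - len(set(rated_item_ids)))
    order = np.argsort(-scores, kind="stable")[:k]    # highest scores first
    return order, scores[order]

item_matrix = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 3.0]])
user_vec = np.array([1.0, 1.0])
ids, scores = top_k_unrated(user_vec, item_matrix, rated_item_ids={3}, k=2)
print(ids, scores)  # [2 0] [2. 1.]
```

Item 3 would score highest (3.0), but it is skipped because the user has already rated it.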

Model 3: Embeddings + Dense Layers

Both of our previous models used only embedding layers and simple operations between them. In this example, we extend the concept by adding dense layers to the network, which can help improve the results further. This time, we don't multiply the user and item embeddings. Instead, we concatenate them and feed the result to the dense layers.

Create Model

Our model for this section first creates embedding layers for users and items in the setup() method. It then declares two dense layers with 100 and 1 units. The dense layer with a single output unit produces our output.

Inside the __call__() method, we first apply the user embedding layer to user ids and the item embedding layer to item ids. We then concatenate the user embeddings and item embeddings side by side: user embeddings of shape (n_samples x 50) and item embeddings of shape (n_samples x 50) give a matrix of shape (n_samples x 100). We feed this combined matrix to the dense layer with 100 units and apply the relu (rectified linear unit) activation function to its output. The relu output is then fed to the dense layer with 1 unit, whose output is our rating prediction. The two dense layers add extra weights for the model to learn, which can further improve the results of the network.

import jax
from jax import numpy as jnp
from flax import linen

n_factors = 50

class RecSystemWithDenseLayers(linen.Module):
    n_users = n_users
    n_items = n_items
    n_factors = n_factors

    def setup(self):
        self.user_embeddings = linen.Embed(self.n_users, self.n_factors, name="User Embeddings")
        self.item_embeddings = linen.Embed(self.n_items, self.n_factors, name="Item Embeddings")
        self.linear1 = linen.Dense(100, name="Dense1")  # hidden layer
        self.linear2 = linen.Dense(1, name="Dense2")    # single output: predicted rating

    def __call__(self, X_batch):
        users = self.user_embeddings(X_batch[:,0])   # (batch, 50)
        items = self.item_embeddings(X_batch[:,1])   # (batch, 50)
        comb_embed = jnp.hstack((users,items))       # concatenate -> (batch, 100)

        x = self.linear1(comb_embed)
        x = linen.relu(x)

        x = self.linear2(x)                          # (batch, 1)
        return x

seed = jax.random.PRNGKey(0)

rec_system = RecSystemWithDenseLayers()

params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    if "Embedding" in layer_params[0]:
        weights = layer_params[1]["embedding"]
        print("\tLayer Weights : {}".format(weights.shape))
    else:
        weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
        print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : User Embeddings
	Layer Weights : (66351, 50)
Layer Name : Item Embeddings
	Layer Weights : (9926, 50)
Layer Name : Dense1
	Layer Weights : (100, 100), Biases : (100,)
Layer Name : Dense2
	Layer Weights : (100, 1), Biases : (1,)
preds = rec_system.apply(params, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

preds.shape
(100, 1)
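Flax's Dense layer computes `inputs @ kernel + bias`, so the shapes in this model can be traced with plain jax.numpy. The sketch below uses random stand-in weights (not trained ones) and 8 samples. It also shows something worth keeping in mind about the `(100, 1)` output above: subtracting a `(batch,)` target from a `(batch, 1)` prediction broadcasts to a `(batch, batch)` matrix, so squeezing the last axis before computing a loss is a safe habit. Whether the tutorial's TrainModel already handles this isn't shown in this section.

```python
import jax
import jax.numpy as jnp

n_samples, n_factors, hidden_units = 8, 50, 100
k_u, k_i, k_w1, k_w2 = jax.random.split(jax.random.PRNGKey(0), 4)

# Stand-ins for the looked-up user and item embeddings
users = jax.random.normal(k_u, (n_samples, n_factors))
items = jax.random.normal(k_i, (n_samples, n_factors))
comb_embed = jnp.hstack((users, items))                  # (8, 100)

# Random stand-in weights with the same shapes as Dense1 and Dense2
W1, b1 = 0.1 * jax.random.normal(k_w1, (2 * n_factors, hidden_units)), jnp.zeros(hidden_units)
W2, b2 = 0.1 * jax.random.normal(k_w2, (hidden_units, 1)), jnp.zeros(1)

hidden = jax.nn.relu(comb_embed @ W1 + b1)               # (8, 100)
preds = hidden @ W2 + b2                                 # (8, 1)

targets = jnp.zeros(n_samples)                           # (8,)
print((preds - targets).shape)               # (8, 8): accidental broadcasting
print((preds.squeeze(-1) - targets).shape)   # (8,): one error per sample
```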

Train Model

Below, we train our new recommendation system for 10 epochs with a batch size of 10,000 and a learning rate of 0.001. The final validation loss (1.339) is better than that of both of our previous models, which were based only on embeddings. The training loss is also much lower after the first epoch (2.272) than it was for the bias model.

seed = jax.random.PRNGKey(0)
epochs=10
batch_size = 10000
learning_rate=0.001

rec_system = RecSystemWithDenseLayers()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(params)

final_params = TrainModel(X_train,Y_train, X_test, Y_test, epochs,params,optimizer_state,batch_size)
100%|██████████| 567/567 [01:56<00:00,  4.88it/s]
Train MSELoss : 2.272
Valid MSELoss : 1.493
100%|██████████| 567/567 [01:50<00:00,  5.14it/s]
Train MSELoss : 1.471
Valid MSELoss : 1.483
100%|██████████| 567/567 [01:50<00:00,  5.11it/s]
Train MSELoss : 1.417
Valid MSELoss : 1.443
100%|██████████| 567/567 [01:49<00:00,  5.17it/s]
Train MSELoss : 1.370
Valid MSELoss : 1.436
100%|██████████| 567/567 [01:50<00:00,  5.13it/s]
Train MSELoss : 1.369
Valid MSELoss : 1.374
100%|██████████| 567/567 [01:50<00:00,  5.14it/s]
Train MSELoss : 1.367
Valid MSELoss : 1.415
100%|██████████| 567/567 [01:50<00:00,  5.13it/s]
Train MSELoss : 1.363
Valid MSELoss : 1.369
100%|██████████| 567/567 [01:51<00:00,  5.08it/s]
Train MSELoss : 1.355
Valid MSELoss : 1.375
100%|██████████| 567/567 [01:54<00:00,  4.97it/s]
Train MSELoss : 1.351
Valid MSELoss : 1.448
100%|██████████| 567/567 [01:52<00:00,  5.05it/s]
Train MSELoss : 1.359
Valid MSELoss : 1.339
gc.collect()
21

Predict Ratings

In this section, we use our new recommendation system with dense layers to make predictions. The results are broadly similar to those of our previous systems: for the titles this user rated 10, the predictions are 7 or 8.

rng = np.random.RandomState(1)
user_data, user_id = make_predictions(rng, rec_system, final_params)
user_data.sort_values(by=["user_rating"], ascending=False).head(15)
user_id anime_id user_rating pred_rating name genre type episodes average_anime_rating members
2401263 19880 2904 10 8 Code Geass: Hangyaku no Lelouch R2 Action, Drama, Mecha, Military, Sci-Fi, Super ... TV 25 8.98 572888
497418 19880 19815 10 7 No Game No Life Adventure, Comedy, Ecchi, Fantasy, Game, Super... TV 12 8.47 602291
39491 19880 11757 10 7 Sword Art Online Action, Adventure, Fantasy, Game, Romance TV 25 7.83 893100
2331510 19880 1575 10 8 Code Geass: Hangyaku no Lelouch Action, Mecha, Military, School, Sci-Fi, Super... TV 25 8.83 715151
626895 19880 28121 10 7 Dungeon ni Deai wo Motomeru no wa Machigatteir... Action, Adventure, Comedy, Fantasy, Romance TV 13 7.88 336349
2095991 19880 23233 10 7 Shinmai Maou no Testament Action, Demons, Ecchi, Fantasy, Harem, Romance TV 12 7.11 172321
4270539 19880 31478 10 7 Bungou Stray Dogs Action, Comedy, Mystery, Seinen, Supernatural TV 12 7.76 187805
4263788 19880 27991 10 7 K: Return of Kings Action, Super Power, Supernatural TV 13 7.82 114904
2006895 19880 20785 10 7 Mahouka Koukou no Rettousei Magic, Romance, School, Sci-Fi, Supernatural TV 26 7.76 285317
5603987 19880 856 10 7 Utawarerumono Action, Drama, Fantasy, Sci-Fi TV 26 7.78 91034
3639572 19880 23283 10 7 Zankyou no Terror Psychological, Thriller TV 11 8.26 342893
1465635 19880 9253 10 8 Steins;Gate Sci-Fi, Thriller TV 24 9.17 673572
3682791 19880 28623 10 7 Koutetsujou no Kabaneri Action, Drama, Fantasy, Horror TV 12 7.39 253027
2662790 19880 22663 10 7 Seiken Tsukai no World Break Action, Fantasy, Harem, Romance, School, Super... TV 12 7.16 97401
302681 19880 6702 10 7 Fairy Tail Action, Adventure, Comedy, Fantasy, Magic, Sho... TV 175 8.22 584590

Recommend New Items

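In this section, we recommend new items using our new recommendation system. Although the loss values are close, this top 10 is quite different from Model 2's list: several recommendations (e.g., Ketsuinu, Maroko, Mak Dau Xiang Dang Dang) are niche titles with only a few hundred members. Items with very few ratings appear in only a handful of training batches, so their embeddings are updated rarely and their predicted ratings may be less reliable.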

recommendation_df = recommend_new_items(user_id, user_data, rec_system, final_params)
recommendation_df.head(10)
user_id anime_id pred_rating name genre type episodes average_anime_rating members
0 19880 8353 10 Ketsuinu Comedy TV 13 5.04 270
1 19880 6579 9 Chikyuu Bouei Kazoku Special Comedy, Kids, Mecha, School, Super Power Special 1 6.61 215
2 19880 9819 9 Mak Dau Xiang Dang Dang Comedy Movie 1 6.81 103
3 19880 9817 9 Mak Dau Goo Si Comedy Movie 1 7.02 136
4 19880 4640 9 Maroko Comedy, Sci-Fi Movie 1 6.66 353
5 19880 1033 8 Sennen Joyuu Action, Adventure, Drama, Fantasy, Historical,... Movie 1 8.34 58492
6 19880 7785 8 Yojouhan Shinwa Taikei Mystery, Psychological, Romance TV 11 8.65 122531
7 19880 7761 8 Masuda Kousuke Gekijou Gag Manga Biyori + Comedy TV 12 7.41 1123
8 19880 338 8 Rose of Versailles Adventure, Drama, Historical, Romance, Shoujo TV 40 8.40 32188
9 19880 3371 8 Ginga Eiyuu Densetsu Gaiden: Senoku no Hoshi, ... Action, Military, Sci-Fi, Space OVA 24 8.20 8621
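One possible remedy, which the tutorial itself does not use, is to drop items below a popularity threshold before ranking. The sketch below uses the item ids, predicted ratings, and member counts of five rows from the table above. The helper name and the `min_members` threshold of 1,000 are arbitrary assumptions. With a pandas DataFrame of all candidate items, the same idea would be a boolean mask on its `members` column.

```python
import numpy as np

def filter_by_popularity(item_ids, pred_ratings, members, min_members=1000, k=10):
    """Hypothetical helper: rank only items with at least `min_members` members."""
    keep = members >= min_members
    ids, preds = item_ids[keep], pred_ratings[keep]
    order = np.argsort(-preds, kind="stable")[:k]
    return ids[order], preds[order]

# Values copied from five rows of the recommendation table above
item_ids = np.array([8353, 6579, 9819, 1033, 7785])
pred_ratings = np.array([10, 9, 9, 8, 8])
members = np.array([270, 215, 103, 58492, 122531])

ids, preds = filter_by_popularity(item_ids, pred_ratings, members)
print(ids, preds)  # [1033 7785] [8 8]
```

Only Sennen Joyuu (1033) and Yojouhan Shinwa Taikei (7785) pass the threshold. The right cut-off is a trade-off against the goal of surfacing new or niche items.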

Model 4: Embeddings + More Dense Layers

In this section, we try one more model, which has one extra dense layer compared to our previous model. The goal of this experiment is to check whether adding another dense layer improves the results further.

Create Model

The model for this example has almost the same code as the previous one. The only difference is one extra dense layer of 50 units, followed by relu, which we apply after the dense layer of 100 units. As a result, the output layer now takes 50 inputs instead of 100.

from flax import linen

n_factors = 50

class RecSystemWithDenseLayers(linen.Module):
    n_users = n_users
    n_items = n_items
    n_factors = n_factors

    def setup(self):
        self.user_embeddings = linen.Embed(self.n_users, self.n_factors, name="User Embeddings")
        self.item_embeddings = linen.Embed(self.n_items, self.n_factors, name="Item Embeddings")
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(50, name="Dense2")
        self.linear3 = linen.Dense(1, name="Dense3")

    def __call__(self, X_batch):
        users = self.user_embeddings(X_batch[:,0])
        items = self.item_embeddings(X_batch[:,1])
        comb_embed = jnp.hstack((users,items))

        x = self.linear1(comb_embed)
        x = linen.relu(x)

        x = self.linear2(x)
        x = linen.relu(x)

        x = self.linear3(x)

        return x
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

rec_system = RecSystemWithDenseLayers()

params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    if "Embedding" in layer_params[0]:
        weights = layer_params[1]["embedding"]
        print("\tLayer Weights : {}".format(weights.shape))
    else:
        weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
        print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : User Embeddings
	Layer Weights : (66351, 50)
Layer Name : Item Embeddings
	Layer Weights : (9926, 50)
Layer Name : Dense1
	Layer Weights : (100, 100), Biases : (100,)
Layer Name : Dense2
	Layer Weights : (100, 50), Biases : (50,)
Layer Name : Dense3
	Layer Weights : (50, 1), Biases : (1,)
preds = rec_system.apply(params, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

preds.shape
(100, 1)

Train Model

Below, we train our new recommendation system for 10 epochs with a batch size of 10,000 and a learning rate of 0.001. The loss values show that this model performs about the same as the previous one: its final validation MSE is 1.386 compared to 1.339 for Model 3, and its validation loss fluctuates more between epochs (e.g., 1.674 after the fifth epoch). Adding one extra dense layer does not seem to bring much improvement.

seed = jax.random.PRNGKey(0)

epochs=10
batch_size = 10000
learning_rate=0.001

rec_system = RecSystemWithDenseLayers()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(params)

final_params = TrainModel(X_train, Y_train, X_test, Y_test, epochs,params,optimizer_state,batch_size)
100%|██████████| 567/567 [02:17<00:00,  4.12it/s]
Train MSELoss : 1.995
Valid MSELoss : 1.484
100%|██████████| 567/567 [02:16<00:00,  4.15it/s]
Train MSELoss : 1.463
Valid MSELoss : 1.469
100%|██████████| 567/567 [02:14<00:00,  4.20it/s]
Train MSELoss : 1.395
Valid MSELoss : 1.458
100%|██████████| 567/567 [02:14<00:00,  4.23it/s]
Train MSELoss : 1.400
Valid MSELoss : 1.479
100%|██████████| 567/567 [02:16<00:00,  4.14it/s]
Train MSELoss : 1.403
Valid MSELoss : 1.674
100%|██████████| 567/567 [02:13<00:00,  4.24it/s]
Train MSELoss : 1.343
Valid MSELoss : 1.631
100%|██████████| 567/567 [02:16<00:00,  4.17it/s]
Train MSELoss : 1.344
Valid MSELoss : 1.454
100%|██████████| 567/567 [02:14<00:00,  4.22it/s]
Train MSELoss : 1.354
Valid MSELoss : 1.450
100%|██████████| 567/567 [02:14<00:00,  4.21it/s]
Train MSELoss : 1.362
Valid MSELoss : 1.373
100%|██████████| 567/567 [02:14<00:00,  4.20it/s]
Train MSELoss : 1.354
Valid MSELoss : 1.386
gc.collect()
21
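A quick parameter count, using the layer shapes printed earlier, helps put this result in perspective: almost all weights sit in the embedding tables, and the dense layers make up a tiny fraction of the model. Parameter count alone doesn't determine accuracy, so treat this as context rather than an explanation. The helper `dense_params` is ours.

```python
def dense_params(n_in, n_out):
    """Weights of a dense layer: kernel (n_in x n_out) plus bias (n_out)."""
    return n_in * n_out + n_out

embedding_params = (66351 + 9926) * 50   # user + item embedding tables (n_factors = 50)
model3_dense = dense_params(100, 100) + dense_params(100, 1)
model4_dense = dense_params(100, 100) + dense_params(100, 50) + dense_params(50, 1)

print(embedding_params, model3_dense, model4_dense)   # 3813850 10201 15201
share = model4_dense / (embedding_params + model4_dense)
print(f"{share:.2%}")   # share of dense weights in Model 4 (roughly 0.4%)
```

The extra 50-unit layer adds 5,000 weights to Model 4 (15,201 vs 10,201 dense parameters), which is small next to the roughly 3.8 million embedding weights.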

Predict Ratings

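In this section, we use our new model to make predictions. The results look slightly better than those of our previous models: for the titles this user rated 10, the predictions are 8 or 9 instead of mostly 7. This is a single user, though, so the difference shouldn't be over-interpreted.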

rng = np.random.RandomState(1)
user_data, user_id = make_predictions(rng, rec_system, final_params)
user_data.sort_values(by=["user_rating"], ascending=False).head(15)
user_id anime_id user_rating pred_rating name genre type episodes average_anime_rating members
2401263 19880 2904 10 9 Code Geass: Hangyaku no Lelouch R2 Action, Drama, Mecha, Military, Sci-Fi, Super ... TV 25 8.98 572888
497418 19880 19815 10 8 No Game No Life Adventure, Comedy, Ecchi, Fantasy, Game, Super... TV 12 8.47 602291
39491 19880 11757 10 8 Sword Art Online Action, Adventure, Fantasy, Game, Romance TV 25 7.83 893100
2331510 19880 1575 10 9 Code Geass: Hangyaku no Lelouch Action, Mecha, Military, School, Sci-Fi, Super... TV 25 8.83 715151
626895 19880 28121 10 8 Dungeon ni Deai wo Motomeru no wa Machigatteir... Action, Adventure, Comedy, Fantasy, Romance TV 13 7.88 336349
2095991 19880 23233 10 8 Shinmai Maou no Testament Action, Demons, Ecchi, Fantasy, Harem, Romance TV 12 7.11 172321
4270539 19880 31478 10 8 Bungou Stray Dogs Action, Comedy, Mystery, Seinen, Supernatural TV 12 7.76 187805
4263788 19880 27991 10 8 K: Return of Kings Action, Super Power, Supernatural TV 13 7.82 114904
2006895 19880 20785 10 8 Mahouka Koukou no Rettousei Magic, Romance, School, Sci-Fi, Supernatural TV 26 7.76 285317
5603987 19880 856 10 8 Utawarerumono Action, Drama, Fantasy, Sci-Fi TV 26 7.78 91034
3639572 19880 23283 10 8 Zankyou no Terror Psychological, Thriller TV 11 8.26 342893
1465635 19880 9253 10 9 Steins;Gate Sci-Fi, Thriller TV 24 9.17 673572
3682791 19880 28623 10 8 Koutetsujou no Kabaneri Action, Drama, Fantasy, Horror TV 12 7.39 253027
2662790 19880 22663 10 8 Seiken Tsukai no World Break Action, Fantasy, Harem, Romance, School, Super... TV 12 7.16 97401
302681 19880 6702 10 8 Fairy Tail Action, Adventure, Comedy, Fantasy, Magic, Sho... TV 175 8.22 584590

Recommend New Items

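In this section, we recommend new items to the user with our new system. Half of the top 10 (Chikyuu Bouei Kazoku Special, Maroko, Ketsuinu, Mak Dau Xiang Dang Dang, and Yojouhan Shinwa Taikei) also appeared in Model 3's list, and once again several recommendations are niche titles with only a few hundred members.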

recommendation_df = recommend_new_items(user_id, user_data, rec_system, final_params)
recommendation_df.head(10)
user_id anime_id pred_rating name genre type episodes average_anime_rating members
0 19880 6579 10 Chikyuu Bouei Kazoku Special Comedy, Kids, Mecha, School, Super Power Special 1 6.61 215
1 19880 4640 10 Maroko Comedy, Sci-Fi Movie 1 6.66 353
2 19880 8353 10 Ketsuinu Comedy TV 13 5.04 270
3 19880 8196 10 Kawasaki Frontale x Tentai Senshi Sunred 2nd S... Comedy, Parody, Seinen, Sports, Super Power Special 1 6.29 321
4 19880 9819 10 Mak Dau Xiang Dang Dang Comedy Movie 1 6.81 103
5 19880 6630 10 Asari-chan: Ai no Marchen Shoujo Adventure, Shoujo, Slice of Life Movie 1 6.48 170
6 19880 7785 9 Yojouhan Shinwa Taikei Mystery, Psychological, Romance TV 11 8.65 122531
7 19880 801 9 Ghost in the Shell: Stand Alone Complex 2nd GIG Action, Mecha, Military, Mystery, Police, Sci-... TV 26 8.57 113993
8 19880 3297 9 Aria The Origination Fantasy, Sci-Fi, Shounen, Slice of Life TV 13 8.64 56162
9 19880 263 9 Hajime no Ippo Comedy, Drama, Shounen, Sports TV 75 8.83 157670

This ends our small tutorial explaining how we can create recommender systems using the collaborative filtering approach and the Flax (JAX) framework. The loss values of all models are quite close, but we think that Model 3, which combines embeddings with dense layers, did a slightly better job than the others. We recommend trying different settings to improve the results before settling on any approach. Please feel free to let us know your views in the comments section.

Sunny Solanki
