Recommender systems predict the rating that a particular user would give to a particular item. They are commonly used by online platforms like Amazon, Netflix, Hotstar, etc. These systems try to understand the taste of the user and then recommend items (movies, books, shopping items, etc.) based on it. Collaborative filtering is one way of implementing a recommender system. In collaborative filtering, we represent each user and each item with a vector of floats of a pre-decided size, and each individual user and item gets its own vector. We then train our neural network to update these vectors so that when we multiply a user vector by an item vector, the result is close to the actual rating given by that user to that item. We try to bring the predictions near to every user's rating of a particular item. During this process, the network learns about user taste as well as item attributes. Once the vectors have been updated through training, we can use them to recommend items to a user that he/she has not rated yet. During training, we update the vectors of many users and items together, hence users with similar tastes will generally be recommended the same items.
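For intuition, the predicted rating in this scheme is simply the dot product of a user vector and an item vector. Below is a tiny sketch with made-up numbers (purely illustrative; the real vectors are learned during training).

import numpy as np

## Hypothetical 4-dimensional vectors for one user and one item (real models use more factors).
user_vector = np.array([0.9, 0.1, 0.4, 0.7])   ## learned "taste" of the user
item_vector = np.array([1.1, 0.2, 0.5, 0.9])   ## learned attributes of the item

predicted_rating = (user_vector * item_vector).sum()  ## dot product
print(predicted_rating)  ## training nudges both vectors so this approaches the actual rating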
As a part of this tutorial, we'll explain how to implement collaborative filtering using the Flax framework of Python. Flax is a deep learning framework designed on top of JAX. Deep learning frameworks generally refer to the vectors that we discussed earlier as embeddings and provide embedding layers that map user/item ids to their vectors. We expect that the reader has a background in Flax/JAX. Please feel free to explore the links below if you want to refresh some Flax/JAX concepts, as it'll help with this tutorial.
We'll be using the Japanese Anime dataset available on Kaggle as a part of this tutorial. Please feel free to download it to follow along.
Below, we have listed important sections of our tutorial to give a summary of the content covered in this guide.
Below, we have imported the necessary libraries and printed the versions that we have used in this tutorial.
import flax
print("FLAX Version : {}".format(flax.__version__))
import jax
print("JAX Version : {}".format(jax.__version__))
import optax
print("OPTAX Version : {}".format(optax.__version__))
In this section, we'll load the ratings and related datasets that we downloaded from Kaggle and prepare them for the collaborative filtering task.
The dataset has 2 CSV files.
Below, we have loaded the rating data available in the rating.csv file. It has the user id, anime id, and rating given by the user to that anime. There are entries in the dataframe where the user has watched an anime but has not given a rating; those entries have a rating of -1. We have removed them from our data as we want only entries where the user has actually given a rating.
import pandas as pd
import numpy as np
import gc
ratings_df = pd.read_csv("/kaggle/input/anime-recommendations-database/rating.csv")
ratings_df = ratings_df[ratings_df.rating!=-1]
print("Data : {}".format(ratings_df.shape))
ratings_df.head()
Below, we have loaded the anime data from the anime.csv file. It has a mapping from anime id to anime name, genre, type, episodes, average rating, and members. We have printed the first few rows to show a glimpse of the dataset.
anime_df = pd.read_csv("/kaggle/input/anime-recommendations-database/anime.csv")
print("Anime : {}".format(anime_df.shape))
anime_df.head()
In this section, we have merged the rating dataset with the anime dataset using the pandas merge() method. We have merged the datasets on the anime_id column, which is common to the two dataframes. After merging, the dataset has two columns giving rating information, hence we have renamed the one from the rating dataframe to user_rating and the one from the anime dataframe to average_anime_rating.
ratings_df = ratings_df.merge(anime_df, on="anime_id")
ratings_df = ratings_df.rename(columns={"rating_x":"user_rating", "rating_y": "average_anime_rating"})
ratings_df.head()
In this section, we are calculating counts of unique items and user ids. We are also checking whether there are any Null values present in the user id, anime id, or rating.
n_items = len(ratings_df['anime_id'].unique())
n_users = len(ratings_df['user_id'].unique())
ratings = ratings_df["user_rating"].unique()
print("Unique Items : {}".format(n_items))
print("Unique Users : {}".format(n_users))
print("Ratings : {}".format(ratings))
print("Are there Null User IDs ? : {}".format(np.alltrue(np.isnan(ratings_df["user_id"].isnull().values))))
print("Are there Null ISBN ? : {}".format(np.alltrue(ratings_df["anime_id"].isnull().values)))
print("Are there Null Ratings ? : {}".format(np.alltrue(ratings_df["user_rating"].isnull().values)))
In this section, we are simply looping through all unique user ids and separating users who have rated only one anime from those who have rated more than one. We have then filtered our final dataframe and kept only user ids who have rated more than one anime. The reason for performing this step is that when we divide the dataset into train and test sets, we'll do the separation based on user ids. If we keep user ids with only one rating then we'll have to keep them in the train set and there won't be an entry for them in the test set. As we have removed entries with a single rating, every user id will be present in both the train and test sets. It's a choice that needs to be made; if the reader wants to keep user ids with only one review then that's fine as well.
from collections import Counter

less_review_users, more_review_users = [], []

for user_id, cnt in Counter(ratings_df["user_id"].values).items():
    if cnt == 1:
        less_review_users.append(user_id)
    else:
        more_review_users.append(user_id)

print("Users with only single review : {}".format(len(less_review_users)))
print("Users with multiple reviews : {}".format(len(more_review_users)))

ratings_df = ratings_df[ratings_df["user_id"].isin(more_review_users)]
In this section, we are simply shuffling the dataframe using the pandas dataframe's sample() method (with frac=1.0). We have also created dictionaries that map anime ids to indexes and to their titles, as they'll be useful later. We have then performed checks to see how many user ids, anime ids, and unique rating values are present after we removed entries with a single review.
unique_items = ratings_df["anime_id"].unique()
item_to_idx = dict(zip(unique_items,range(len(unique_items))))
item_to_title = dict(list(zip(ratings_df["anime_id"].values, ratings_df["name"].values)))
#ratings_df = ratings_df.reset_index()
ratings_df = ratings_df.sample(frac=1.0,random_state=123)
ratings_df.head()
n_items = len(ratings_df['anime_id'].unique())
n_users = len(ratings_df['user_id'].unique())
ratings = ratings_df["user_rating"].unique()
print("Unique Items : {}".format(n_items))
print("Unique Users : {}".format(n_users))
print("Ratings : {}".format(ratings))
In this section, we have just pivoted the first 10 entries of our dataframe to a tabular view of ratings given by users to particular anime.
pivoted_df = pd.pivot(ratings_df.head(10), index="user_id", columns="name", values="user_rating")
pivoted_df.head(10)
Please make a NOTE that we are calling gc.collect() a few times in our tutorial as it invokes the Python garbage collector, which frees up memory.
gc.collect()
In this section, we have created a small function that divides the dataset into train and test sets. For our case, the data features are user id and anime id and target values are ratings. We'll be feeding user id and anime id to our neural network and it'll predict the rating given by that user to the anime.
In order to divide the dataset into train and test sets, we first take the unique user ids. Then, we loop through each user id, take all entries of that user id (all ratings given), and divide them into train (90% of entries) and test (10%) sets. We have also handled the case where there is only one rating for a user, in which case we keep it in the train set; although we have already removed entries with a single rating, the code below will handle them if the reader decides to keep them.
Our datasets (X_train, X_test) have two entries (user id, anime id). The target values (Y_train, Y_test) are ratings.
def train_test_split(df):
    unique_users = np.unique(df["user_id"].values)
    X_train, X_test, Y_train, Y_test = [], [], [], []

    for i, user_id in enumerate(unique_users):
        ratings_temp = df[df["user_id"] == user_id]

        if ratings_temp.shape[0] == 1: ## If only one sample per user then give it to train set
            X_train.extend(ratings_temp[["user_id","anime_id"]].values.tolist())
            Y_train.append(ratings_temp["user_rating"].values[0])
        else:
            idx = int(ratings_temp.shape[0] * 0.9) ## 90% train and 10% test
            ## Populate train data
            X_train.extend(ratings_temp[["user_id","anime_id"]].values[:idx].tolist())
            Y_train.extend(ratings_temp["user_rating"].values[:idx].tolist())
            ## Populate test data
            X_test.extend(ratings_temp[["user_id","anime_id"]].values[idx:].tolist())
            Y_test.extend(ratings_temp["user_rating"].values[idx:].tolist())

        if (i+1) % 10000 == 0:
            print("{} users completed.".format(i+1))

    return np.array(X_train), np.array(X_test), np.array(Y_train), np.array(Y_test)
%time X_train, X_test, Y_train, Y_test = train_test_split(ratings_df)
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
Below, we have again performed simple checks for the count of user ids that are present in train but not in test and vice-versa. This is done to verify that the above code properly divided data into train and test sets.
train_users = set(np.unique(X_train[:,0]))
test_users = set(np.unique(X_test[:,0]))
train_not_test = train_users.difference(test_users)
test_not_train = test_users.difference(train_users)
print("Users in Train but not Test : {}".format(len(train_not_test)))
print("Users in Test but not Train : {}".format(len(test_not_train)))
In this section, we'll design our neural network using embeddings. Embeddings are weights that map ids (user id / item id) to their vector values. As we discussed earlier for collaborative filtering, we keep a vector of floats for each unique user and item. Let's say there are 10 unique users and we maintain a vector of 50 floats per user; then we'll have 10 such vectors of 50 floats, one per user. The same kind of vectors will be maintained for items (movies, animes, books, etc.). These vectors are generally referred to as embeddings. The embeddings of a user generally represent the taste of that user (whether he/she likes action, adventure, comedy, or horror movies). In the same way, the embeddings of an item represent the characteristics of that item (whether the movie is horror, action, or adventure). The embedding layer for user ids holds the embeddings for each user and maps a user id to its embedding. The same goes for the embedding layer for item ids; it simply maps an item id to its embedding (vector of floats). We generally decide the size of these vectors up front, and it'll be the same for all user ids and item ids.
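As a small standalone illustration (with arbitrary sizes, not the ones used later), the Embed layer of flax.linen maps a batch of integer ids to their vectors.

import jax
from jax import numpy as jnp
from flax import linen

embed = linen.Embed(num_embeddings=10, features=5)       ## 10 ids, 5 floats per id (arbitrary sizes)
variables = embed.init(jax.random.PRNGKey(0), jnp.array([0]))

ids = jnp.array([3, 7, 3])                                ## a batch of ids
vectors = embed.apply(variables, ids)                     ## shape (3, 5); the same id maps to the same vector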
In this section, we have defined our simple network for the collaborative filtering task. We have set vector size for each user and item to 50 which means that each individual user and item will be represented with 50 floats. The number of floats is generally referred to as factors and is one of the network hyperparameters that needs to be tuned. We had tried different values for factors like 40, 50, 70, 100, and 120. Then, we decided on 50 as our final value.
We have created the model by extending the flax.linen.Module class. The setup() method initializes the layers and the __call__() method performs the actual forward pass through the network. We have declared two embedding layers (one for users and one for items) using the Embed layer of flax.linen. The user embedding layer holds a float array of shape (n_users, n_factors) which maps a user id to its vector of floats. The item embedding layer holds a float array of shape (n_items, n_factors) which maps an item id to its vector of floats.
The forward pass through the network (the __call__() method) simply multiplies user embeddings by item embeddings to generate a rating. The network takes as input a tuple of (user_id, item_id); it looks up the vectors for that user id and item id, multiplies them element-wise, and sums the result to predict the rating given by the user to the item. It can do this for a whole batch of user ids and item ids. During training, we then update the weights of the user and item embeddings to improve the rating prediction.
If the process of network creation and training seems a little hard to grasp, we suggest that the reader go through our tutorial on Flax that explains how to create neural networks, as it'll help with this one as well.
We have treated the process of predicting ratings as a REGRESSION task even though the ratings are integers in the range [1,10]. The reason for treating rating prediction as regression is that we want the model to predict a rating that is near the actual rating; we can then round the float prediction to an integer. A classification task treats all classes as independent of each other and does not enforce any ORDERING on the target classes; in case of a wrong prediction, it can predict any of the possible classes, whereas a regression model will still predict a rating that is near the actual rating even when it is wrong. In our case, the ratings have an ordering where a rating of 3 is better than 2, 4 is better than 3, and so on. We can enforce this ordering by treating the task as REGRESSION, which tries to predict a rating that is near the actual one.
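As a small aside on the regression treatment, a raw float prediction can simply be clipped to the valid range and rounded to the nearest integer rating. The to_rating helper below is just for illustration; the prediction functions later in this tutorial use a similar clip-and-round step.

import numpy as np

def to_rating(pred, low=1, high=10):
    ## Clip the raw regression output to the valid rating range and round to the nearest integer.
    return int(np.clip(np.rint(pred), low, high))

print(to_rating(7.8))   ## -> 8
print(to_rating(12.3))  ## -> 10 (clipped to the maximum rating)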
Please make a NOTE that all of our models approach the rating prediction task as a REGRESSION task even though the ratings look like independent classes and hence like a CLASSIFICATION task. The reason is that we can't enforce an ordering by treating it as a classification task.
from flax import linen

n_factors = 50

class SimpleRecSystem(linen.Module):
    n_users = n_users
    n_items = n_items
    n_factors = n_factors

    def setup(self):
        self.user_embeddings = linen.Embed(self.n_users, self.n_factors, name="User Embeddings")
        self.item_embeddings = linen.Embed(self.n_items, self.n_factors, name="Item Embeddings")

    def __call__(self, X_batch):
        users = self.user_embeddings(X_batch[:,0]) ## Look up user vectors
        items = self.item_embeddings(X_batch[:,1]) ## Look up item vectors
        return (users * items).sum(axis=1) ## Per-sample dot product = predicted rating
Below, we have initialized the network and printed the shape of the embeddings. We have also performed a forward pass through the network for verification purposes.
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

rec_system = SimpleRecSystem()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights = layer_params[1]["embedding"]
    print("\tLayer Weights : {}".format(weights.shape))

preds = rec_system.apply(params, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

preds.shape
In this section, we have defined the loss function that we have used in our case. We have used squared error loss (the function below actually returns its square root, i.e., the root mean squared error) which tries to reduce the distance between the actual rating and the predicted rating. We have also defined mean absolute error, which can be used as a loss function for regression tasks as well.
def MSELoss(params, input_data, actual):
    ## Note: this returns the square root of the mean squared error (i.e., RMSE).
    preds = rec_system.apply(params, input_data)
    return jnp.sqrt(jnp.power(actual.squeeze() - preds.squeeze(), 2).mean())

def MAELoss(params, input_data, actual):
    preds = rec_system.apply(params, input_data)
    return jnp.abs(actual.squeeze() - preds.squeeze()).mean()
In this section, we train our embeddings-based neural network. We have created a simple function that we'll use for training. The function simply loops through the train data in batches. For each batch (a bunch of (user id, item id) pairs), it predicts ratings, calculates the loss, calculates gradients, and updates the weights using the gradients. It also calculates the loss on our test dataset which we have provided as a validation set. The function prints both the train and validation loss after each epoch and, at last, returns the updated network parameters.
from jax import value_and_grad
from tqdm import tqdm

def TrainModel(X, Y, X_val, Y_val, epochs, params, optimizer_state, batch_size=256):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(MSELoss)(params, X_batch, Y_batch) ## Forward Pass and Loss Calculation

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            params = optax.apply_updates(params, updates)

            losses.append(loss) ## Record Loss

        print("Train MSELoss : {:.3f}".format(jnp.array(losses).mean()))

        val_loss = MSELoss(params, X_val, Y_val)
        print("Valid MSELoss : {:.3f}".format(val_loss))

        gc.collect()

    return params
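As an optional speed-up that we have not used in this tutorial, the loss-and-gradient computation inside the batch loop could be jit-compiled with JAX. A sketch of this idea (note that JAX recompiles for every new input shape, so fixed-size batches work best):

from jax import jit, value_and_grad

## Hypothetical variant: compile the forward pass + gradient computation once and reuse it.
loss_and_grad = jit(value_and_grad(MSELoss))

## Inside the batch loop one would then call:
## loss, gradients = loss_and_grad(params, X_batch, Y_batch)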
Below, we are actually training our network using the function defined above. We have set the number of epochs to 10, the batch size to 10,000, and the learning rate to 0.001. We have then initialized the network and its weights. Then, we have initialized the Adam optimizer to update the network parameters. At last, we have called our function to train the network.
We can notice from the loss values printed after each epoch that our model seems to improve with each epoch. The final train and validation losses are around 1.4, which indicates that our rating predictions are generally off by about 1.4 from the actual rating, i.e., reasonably close to the actual rating for the majority of samples.
Later on, we'll try more models to bring this gap further down (below 1.4).
seed = jax.random.PRNGKey(0)
epochs=10
batch_size = 10000
learning_rate=0.001
rec_system = SimpleRecSystem()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))
optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(params)
final_params = TrainModel(X_train, Y_train, X_test, Y_test, epochs, params, optimizer_state, batch_size)
gc.collect()
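To put the loss values above in perspective, one could also compute the mean absolute error of the trained model on the test set using the MAELoss function defined earlier, for example:

test_mae = MAELoss(final_params, X_test, Y_test) ## Average absolute difference between predicted and actual ratings
print("Test MAE : {:.3f}".format(test_mae))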
In this section, we are using our trained model to predict ratings for an existing user id, and we have designed a function for this. The function randomly selects one user id. It then uses our trained model to make predictions for the items that were reviewed by this user. We merge the actual rating and the predicted rating into a single dataframe along with other details that can be helpful. We have also included logic that converts float predictions to integer ratings. We'll be reusing this function in upcoming sections when we try different models.
import math

def make_predictions(rng, rec_system, final_params):
    user_id = rng.choice(more_review_users, size=1)[0] ## Randomly select one user id
    cols_of_interest = ["user_id","anime_id","user_rating","name","genre","type","episodes","average_anime_rating","members"]
    user_data = ratings_df[ratings_df["user_id"] == user_id][cols_of_interest]

    preds = rec_system.apply(final_params, user_data[["user_id","anime_id"]].values)
    preds = jnp.clip(preds, 1, 10.5) ## Keep predictions inside the valid rating range
    ## Round float predictions to the nearest integer rating
    preds = np.array([math.ceil(pred) if (pred-int(pred)) > 0.5 else math.floor(pred) for pred in preds])

    user_data["pred_rating"] = preds
    cols_of_interest.insert(3, "pred_rating")
    return user_data[cols_of_interest], user_id
rng = np.random.RandomState(1)
user_data, user_id = make_predictions(rng, rec_system, final_params)
user_data.sort_values(by=["user_rating"], ascending=False).head(15)
In this section, we are recommending new items to existing users. As our users have not used all items, we can recommend new items to them by using our trained model which knows a little about the taste of users. We have created a small function to recommend new items to users.
The function takes as input the user id, the user's data, the recommender model, and the model parameters. We first retrieve the item ids that have not been rated by the input user. Then, we use our trained model to predict ratings for those unrated items and sort them from high rating to low. At last, we recommend the items for which our model predicted a high rating. The model essentially predicts the rating that the user would have given to an item if he/she had rated it.
def recommend_new_items(user_id, user_data, rec_system, final_params):
    total_item_ids = ratings_df["anime_id"].unique().tolist()
    item_ids_user_watched = user_data["anime_id"].values.tolist()
    item_ids_not_watched_by_user = list(set(total_item_ids).difference(item_ids_user_watched))

    ## Pair the user id with every item id that he/she has not rated yet
    data = np.stack(([user_id]*len(item_ids_not_watched_by_user), item_ids_not_watched_by_user), axis=1)

    preds = rec_system.apply(final_params, data)
    preds = np.array([math.ceil(pred) if (pred-int(pred)) > 0.5 else math.floor(pred) for pred in preds])
    preds = preds.reshape(-1,1)

    recommendation_df = pd.DataFrame(np.hstack((data, preds)), columns=["user_id","anime_id","pred_rating"])
    recommendation_df = recommendation_df.sort_values(by=["pred_rating"], ascending=False)

    cols_of_interest = ["anime_id","name","genre","type","episodes","average_anime_rating","members"]
    recommendation_df = recommendation_df.merge(ratings_df[cols_of_interest].drop_duplicates(), on="anime_id")
    return recommendation_df
recommendation_df = recommend_new_items(user_id, user_data, rec_system, final_params)
recommendation_df.head(10)
The very obvious question that can come to mind when using the collaborative filtering approach for a recommendation system is how to handle new users. In order to handle new users, we need to extend our embeddings with embeddings for the new users and retrain our model, as their embeddings will initially be random. We can retrain the model daily if we are getting many new users, or weekly, bi-weekly, or even monthly if the frequency of new users is low. Initially, we can recommend items that have an overall high rating from the majority of users, and once the user has watched/rated some items, we can update his/her embeddings by training the network.
Another obvious question that comes to mind is how to handle items that are new to the platform. Newly added items will have random embeddings just like new users. One solution is to keep recommending new items randomly to users until we have a few ratings for them. We can then update the embeddings of the new item with those few ratings by training our network. We can also initially highlight newly added items on the home pages of websites, as done by famous platforms like Netflix, Amazon, etc., and send a notification to users who have used an item but have not given a rating.
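One possible way (not shown in this tutorial) to handle a growing user base is to create a larger user embedding table, copy over the already-learned rows, and let only the rows for new users start from small random values. Below is a rough sketch with a hypothetical helper, assuming the layer names used by our models; the resulting matrix would then need to be placed back into the model parameters (with n_users increased) before fine-tuning on the new users' ratings.

import jax
from jax import numpy as jnp

def grow_user_embeddings(old_params, n_new_users, n_factors=50, seed=0):
    ## Hypothetical helper: appends randomly initialised rows for new users to the
    ## learned user embedding matrix. Existing users keep their learned vectors.
    old_emb = old_params["params"]["User Embeddings"]["embedding"]
    new_rows = 0.01 * jax.random.normal(jax.random.PRNGKey(seed), (n_new_users, n_factors))
    return jnp.vstack([old_emb, new_rows])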
In this section, we introduce bias terms in addition to the embeddings. A bias adds one single number per user and per item that can capture extra information (e.g., a user who rates everything high, or an item that most users rate high) and can help the weights settle better during training.
Our model for this section reuses much of the code from the earlier section with a few additions. We have declared user and item biases using Embed layers (with a single feature) in the setup() method. In the __call__() method, we first multiply user embeddings by item embeddings and then add the user and item biases to the result before returning it.
from flax import linen

n_factors = 50

class SimpleRecSystemWithBias(linen.Module):
    n_users = n_users
    n_items = n_items
    n_factors = n_factors

    def setup(self):
        self.user_embeddings = linen.Embed(self.n_users, self.n_factors, name="User Embeddings")
        self.item_embeddings = linen.Embed(self.n_items, self.n_factors, name="Item Embeddings")
        self.user_bias = linen.Embed(self.n_users, 1, name="User Bias")
        self.item_bias = linen.Embed(self.n_items, 1, name="Item Bias")

    def __call__(self, X_batch):
        users = self.user_embeddings(X_batch[:,0])
        items = self.item_embeddings(X_batch[:,1])

        result = (users * items).sum(axis=1)
        result += self.user_bias(X_batch[:,0]).squeeze() + self.item_bias(X_batch[:,1]).squeeze()
        return result
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

rec_system = SimpleRecSystemWithBias()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights = layer_params[1]["embedding"]
    print("\tLayer Weights : {}".format(weights.shape))

preds = rec_system.apply(params, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

preds.shape
In this section, we are training our new recommendation system with bias terms. We are training the network for 10 epochs with a batch size of 10,000 and a learning rate of 0.001. From the loss values printed at the end of the epochs, we can notice that this model seems to have done a little better than our earlier model.
seed = jax.random.PRNGKey(0)
epochs=10
batch_size = 10000
learning_rate=0.001
rec_system = SimpleRecSystemWithBias()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))
optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(params)
final_params = TrainModel(X_train,Y_train, X_test, Y_test, epochs,params,optimizer_state,batch_size)
gc.collect()
In this section, we are making predictions for items that the user has seen and rated. From the results, it seems that this model has done almost the same job as our previous model.
rng = np.random.RandomState(1)
user_data, user_id = make_predictions(rng, rec_system, final_params)
user_data.sort_values(by=["user_rating"], ascending=False).head(15)
In this section, we are recommending new items to the user.
recommendation_df = recommend_new_items(user_id, user_data, rec_system, final_params)
recommendation_df.head(10)
Both of our previous models used only embedding layers and operations between them. In this example, we'll extend the concept and add dense layers to the network, which can help improve the results further. This time, we don't multiply the user and item embeddings; instead, we concatenate them before feeding them to dense layers.
Our model for this section first creates embedding layers for users and items in the setup() method. It then declares two dense layers with 100 and 1 units respectively. The dense layer with a single output unit produces our output.
Inside the __call__() method, we first apply the user embedding layer to the user ids and the item embedding layer to the item ids. Then, we merge the user and item embeddings by concatenating them next to each other. The resulting matrix has shape (n_samples, 100), obtained by concatenating user embeddings of shape (n_samples, 50) with item embeddings of shape (n_samples, 50). We then feed these merged embeddings to our dense layer with 100 units and apply the relu (rectified linear unit) activation function to the output. Finally, we feed the output of relu to the dense layer with 1 unit, which gives our rating prediction. Our model now has extra weights to learn due to the addition of the two dense layers, which can further improve the results of the network.
from flax import linen

n_factors = 50

class RecSystemWithDenseLayers(linen.Module):
    n_users = n_users
    n_items = n_items
    n_factors = n_factors

    def setup(self):
        self.user_embeddings = linen.Embed(self.n_users, self.n_factors, name="User Embeddings")
        self.item_embeddings = linen.Embed(self.n_items, self.n_factors, name="Item Embeddings")
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(1, name="Dense2")

    def __call__(self, X_batch):
        users = self.user_embeddings(X_batch[:,0])
        items = self.item_embeddings(X_batch[:,1])

        comb_embed = jnp.hstack((users, items))

        x = self.linear1(comb_embed)
        x = linen.relu(x)
        x = self.linear2(x)
        return x
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

rec_system = RecSystemWithDenseLayers()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    if "Embedding" in layer_params[0]:
        weights = layer_params[1]["embedding"]
        print("\tLayer Weights : {}".format(weights.shape))
    else:
        weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
        print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))

preds = rec_system.apply(params, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

preds.shape
Below, we have trained our new recommendation system for 10 epochs with a batch size of 10,000 and a learning rate of 0.001. The loss values printed after completion of all epochs are better than those of both our previous models, which were based purely on embeddings.
seed = jax.random.PRNGKey(0)
epochs=10
batch_size = 10000
learning_rate=0.001
rec_system = RecSystemWithDenseLayers()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))
optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(params)
final_params = TrainModel(X_train,Y_train, X_test, Y_test, epochs,params,optimizer_state,batch_size)
gc.collect()
In this section, we have used our new recommendation system with dense layers to make predictions. The results are overall the same as our previous systems.
rng = np.random.RandomState(1)
user_data, user_id = make_predictions(rng, rec_system, final_params)
user_data.sort_values(by=["user_rating"], ascending=False).head(15)
In this section, we have recommended new items using our new recommendation system. Many of the items recommended here are almost the same as the ones recommended by previous models.
recommendation_df = recommend_new_items(user_id, user_data, rec_system, final_params)
recommendation_df.head(10)
In this section, we have tried one more model where we have used one extra dense layer compared to our previous model. This experiment was performed to check whether adding one extra dense layer is helping improve results further or not.
Our model for this example has almost the same code as the previous example, with the only difference being that we have used one extra dense layer. We have applied that extra dense layer of 50 units after our dense layer of 100 units.
from flax import linen

n_factors = 50

class RecSystemWithDenseLayers(linen.Module):
    n_users = n_users
    n_items = n_items
    n_factors = n_factors

    def setup(self):
        self.user_embeddings = linen.Embed(self.n_users, self.n_factors, name="User Embeddings")
        self.item_embeddings = linen.Embed(self.n_items, self.n_factors, name="Item Embeddings")
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(50, name="Dense2")
        self.linear3 = linen.Dense(1, name="Dense3")

    def __call__(self, X_batch):
        users = self.user_embeddings(X_batch[:,0])
        items = self.item_embeddings(X_batch[:,1])

        comb_embed = jnp.hstack((users, items))

        x = self.linear1(comb_embed)
        x = linen.relu(x)
        x = self.linear2(x)
        x = linen.relu(x)
        x = self.linear3(x)
        return x
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

rec_system = RecSystemWithDenseLayers()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    if "Embedding" in layer_params[0]:
        weights = layer_params[1]["embedding"]
        print("\tLayer Weights : {}".format(weights.shape))
    else:
        weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
        print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))

preds = rec_system.apply(params, jax.random.randint(seed, (100, 2), minval=1, maxval=20))

preds.shape
Below, we have trained our new recommendation system for 10 epochs with a batch size of 10,000 and a learning rate of 0.001. From the resulting loss values, we can notice that the model performance is almost the same as our previous model. There does not seem to be much of an improvement with the addition of one extra dense layer.
seed = jax.random.PRNGKey(0)
epochs=10
batch_size = 10000
learning_rate=0.001
rec_system = RecSystemWithDenseLayers()
params = rec_system.init(seed, jax.random.randint(seed, (100, 2), minval=1, maxval=20))
optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(params)
final_params = TrainModel(X_train, Y_train, X_test, Y_test, epochs,params,optimizer_state,batch_size)
gc.collect()
In this section, we have used our new model to make predictions. The results seem a little better compared to our previous models though.
rng = np.random.RandomState(1)
user_data, user_id = make_predictions(rng, rec_system, final_params)
user_data.sort_values(by=["user_rating"], ascending=False).head(15)
In this section, we have recommended new items to users using our new system. The majority of items are almost the same as the ones recommended by previous models.
recommendation_df = recommend_new_items(user_id, user_data, rec_system, final_params)
recommendation_df.head(10)
This ends our small tutorial explaining how to create recommender systems using the collaborative filtering approach and the Flax (JAX) framework. The results of all models seem to be almost the same, but we think that the third model, which combined dense layers with embeddings, did a slightly better job than the others. We would recommend trying different settings to improve results before settling on any approach. Please feel free to let us know your views in the comments section.
If you are more comfortable learning through video tutorials then we would recommend that you subscribe to our YouTube channel.
When going through coding examples, it's quite common to have doubts and errors.
If you have doubts about some code examples or are stuck somewhere when trying our code, send us an email at coderzcolumn07@gmail.com. We'll help you or point you in the direction where you can find a solution to your problem.
You can even send us a mail if you are trying something new and need guidance regarding coding. We'll try to respond as soon as possible.
If you want to