This package contains all the necessary functions to quickly train a model for a collaborative filtering task. Let's start by importing all we'll need.
from fastai.collab import *
Collaborative filtering is when you're tasked to predict how much a user is going to like a certain item. The fastai library contains a CollabDataBunch class that will help you create datasets suitable for training, and a function collab_learner to build a simple model directly from a ratings table. Let's first see how we can get started before delving into the documentation.
For this example, we'll use a small subset of the MovieLens dataset to predict the rating a user would give a particular movie (from 0 to 5). The dataset comes in the form of a csv file where each line is a rating of a movie by a given person.
path = untar_data(URLs.ML_SAMPLE)
ratings = pd.read_csv(path/'ratings.csv')
ratings.head()
We'll first turn the userId and movieId columns into category codes, so that we can replace them with their codes when it's time to feed them to an Embedding layer. This step would be even more important if our csv had names of users, or names of items in it. To do it, we simply have to call a CollabDataBunch factory method.
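Under the hood, this categorization step amounts to what pandas category codes do. Here is a minimal sketch on toy data (not the library's actual implementation):

```python
import pandas as pd

# Toy ratings table with the same layout as the MovieLens csv
ratings = pd.DataFrame({'userId':  [10, 20, 10, 30],
                        'movieId': [101, 101, 205, 307],
                        'rating':  [4.0, 3.5, 5.0, 2.0]})

# Map arbitrary ids to contiguous integer codes 0..n-1,
# which is what an Embedding layer expects as input
user_codes = ratings['userId'].astype('category').cat.codes
item_codes = ratings['movieId'].astype('category').cat.codes

print(user_codes.tolist())   # [0, 1, 0, 2]
print(item_codes.tolist())   # [0, 0, 1, 2]
```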
data = CollabDataBunch.from_df(ratings)
learn = collab_learner(data, n_factors=50, y_range=(0.,5.))
And then immediately begin training
learn.fit_one_cycle(5, 5e-3, wd=0.1)
DataBunch for collaborative filtering.
The init function shouldn't be called directly (as it's that of a basic DataBunch); instead, you'll want to use the following factory method.
Takes a ratings dataframe and splits it randomly for train and test following pct_val (unless it's None). user_name, item_name and rating_name give the names of the corresponding columns (defaulting to the first, second and third columns respectively). Optionally a test dataframe can be passed, and a seed for the separation between training and validation set. The kwargs will be passed to DataBunch.create.
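The random split driven by pct_val and seed can be sketched like this (a simplified stand-in for illustration, not fastai's actual implementation):

```python
import numpy as np
import pandas as pd

def random_split(df, pct_val=0.2, seed=None):
    """Shuffle the rows with an optional seed, then carve off pct_val for validation."""
    rng = np.random.RandomState(seed)
    idx = rng.permutation(len(df))
    cut = int(len(df) * (1 - pct_val))
    return df.iloc[idx[:cut]], df.iloc[idx[cut:]]

ratings = pd.DataFrame({'userId':  list(range(10)),
                        'movieId': list(range(10)),
                        'rating':  [3.0] * 10})
train, valid = random_split(ratings, pct_val=0.2, seed=42)
print(len(train), len(valid))  # 8 2
```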
Creates a Learner suitable for collaborative filtering.
Creates a simple model with Embedding weights and biases for n_factors latent factors. Takes the dot product of the embeddings and adds the biases, then if y_range is specified, feeds the result to a sigmoid rescaled to go from y_range[0] to y_range[1].
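That dot-product-plus-bias forward pass can be sketched in plain numpy (toy sizes chosen for illustration; the real model is a PyTorch module):

```python
import numpy as np

rng = np.random.RandomState(0)
n_users, n_items, n_factors = 5, 7, 4   # toy sizes

# Embedding weights plus a per-user and per-item bias
u_weight = rng.randn(n_users, n_factors)
i_weight = rng.randn(n_items, n_factors)
u_bias   = rng.randn(n_users)
i_bias   = rng.randn(n_items)

def predict(user, item, y_range=(0., 5.)):
    # Dot product of the two embeddings, plus both biases
    dot = u_weight[user] @ i_weight[item] + u_bias[user] + i_bias[item]
    # Sigmoid rescaled so the output spans y_range
    lo, hi = y_range
    return lo + (hi - lo) / (1 + np.exp(-dot))

score = predict(user=2, item=3)   # always lands inside y_range
```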
More specifically, binds data with a model that is either an EmbeddingDotBias with n_factors if use_nn=False or an EmbeddingNN with emb_szs otherwise. In both cases the numbers of users and items will be inferred from the data, y_range can be specified in the kwargs, and you can pass wd to the Learner.
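As a rough picture of the neural variant (hypothetical layer sizes; fastai's actual model is a full PyTorch tabular model), the two embeddings are concatenated and passed through a small ReLU network instead of being multiplied:

```python
import numpy as np

rng = np.random.RandomState(1)
n_users, n_items = 5, 7
emb_szs = {'user': 4, 'item': 6}   # hypothetical embedding sizes

u_weight = rng.randn(n_users, emb_szs['user'])
i_weight = rng.randn(n_items, emb_szs['item'])
# One hidden ReLU layer on the concatenated embeddings
W1 = rng.randn(emb_szs['user'] + emb_szs['item'], 8)
W2 = rng.randn(8, 1)

def predict_nn(user, item, y_range=(0., 5.)):
    x = np.concatenate([u_weight[user], i_weight[item]])
    h = np.maximum(x @ W1, 0.)
    lo, hi = y_range
    return float(lo + (hi - lo) / (1 + np.exp(-(h @ W2)[0])))

score = predict_nn(1, 2)   # again constrained to y_range by the sigmoid
```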