Distributed model training for large-scale data using Dask

A walkthrough of training machine learning models that rely on terabytes of data in a distributed fashion, using Dask and LightGBM.

Miguel Ángel Cárdenas
May 30, 2022
Photo by hp koch on Unsplash

In this post you will find out how to:

  • Use Dask's delayed (lazy) and eager APIs effectively to process large data.
  • Train a model through the Dask API on an ephemeral cluster.
  • Use the distributed LightGBM Python package.

A current trend in machine learning-powered decision making is the data-centric approach to building reliable, long-lived models. This approach improves performance by systematically enriching the dataset while leaving the model's source code unchanged. Processing large datasets (terabytes of data) on a single, monolithic machine quickly becomes a bottleneck, so a distributed approach is the natural solution.

Dask parallelizes many libraries in the Python ecosystem through flexible APIs and high-performance implementations, overcoming the scalability limits of single-machine tools. Let's find out how to train a model in a distributed fashion using a few key libraries: Dask, PyArrow, and LightGBM.

A brief introduction to distributed LightGBM

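LightGBM stands for Light Gradient Boosting Machine. Microsoft Research introduced it at the end of 2016 as a gradient boosting framework in which each weak learner is a decision tree. It handles classification as well as regression and ranking problems. Like random forest, a common baseline method, it builds an ensemble of trees. Unlike random forest, whose trees are trained independently, gradient boosting trains its trees sequentially: each new tree corrects the errors of the ensemble built so far. LightGBM makes the construction of each individual tree fast; it does not train all the trees at once.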
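The toy sketch below makes the sequential nature of boosting concrete. It is written in NumPy and is not LightGBM's code. It implements gradient boosting for squared loss with depth-1 trees (stumps) on a single feature. Every new stump is fitted to the residuals left by all previous stumps, which is why boosting rounds cannot run all at once. The function names and the one-feature setup are illustrative assumptions.

```python
import numpy as np


def fit_stump(x, residual):
    """Fit a depth-1 regression tree (one threshold) to the residuals by least squares."""
    best = None
    for t in np.unique(x)[:-1]:  # every distinct value except the max is a candidate threshold
        left, right = residual[x <= t], residual[x > t]
        loss = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if best is None or loss < best[0]:
            best = (loss, t, left.mean(), right.mean())
    _, threshold, left_value, right_value = best
    return lambda z: np.where(z <= threshold, left_value, right_value)


def gradient_boost(x, y, n_trees=50, learning_rate=0.1):
    """Sequential boosting for squared loss: each new tree fits the current residuals."""
    base = float(y.mean())
    pred = np.full(len(y), base)
    trees = []
    for _ in range(n_trees):
        residual = y - pred  # negative gradient of 0.5 * (y - pred) ** 2
        tree = fit_stump(x, residual)  # depends on every tree trained before it
        pred = pred + learning_rate * tree(x)
        trees.append(tree)
    return lambda z: base + learning_rate * sum(tree(z) for tree in trees)


# Usage: a step-shaped target is learned gradually over 50 sequential rounds.
x = np.linspace(0, 1, 50)
y = (x > 0.5).astype(float)
model = gradient_boost(x, y)
mse = float(np.mean((model(x) - y) ** 2))
baseline_mse = float(np.mean((y - y.mean()) ** 2))
```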

LightGBM is written in C++ and has bindings for several languages; its Dask integration ships with the Python package. The distributed version of the algorithm is not translated into Dask's own low-level primitives (Delayed and Future task graphs). Instead, Dask's part in the training process is to move data around and launch the work on each worker, while the actual computation takes place in LightGBM's C++ backend.

One takeaway from this algorithm is that distributed learning becomes fast and feasible through the discretization of continuous features: every feature is represented by a histogram, so each boundary between two bins becomes a candidate split.
Because this representation is compact, the information “transmitted” between workers is small, which keeps communication latency low while the loss is evaluated for candidate splits.
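The histogram idea fits in a few lines of NumPy. This is a simplified illustration, not LightGBM's implementation. Bin edges come from quantiles. The split gain uses the common second-order form, sum of G²/(H + λ) over the two children minus the same term for the parent, where G and H are gradient and hessian sums; LightGBM's exact formula adds more regularization terms. Only the bin boundaries are scanned, not every distinct feature value. `bin_feature` and `best_split` are hypothetical names.

```python
import numpy as np


def bin_feature(values, max_bin=16):
    """Discretize a continuous feature into quantile bins; returns (bin index per row, edges)."""
    inner_quantiles = np.linspace(0, 1, max_bin + 1)[1:-1]
    edges = np.unique(np.quantile(values, inner_quantiles))
    # side="right" puts a value v in bin i when edges[i-1] <= v < edges[i]
    return np.searchsorted(edges, values, side="right"), edges


def best_split(bin_idx, edges, grad, hess, reg_lambda=1.0):
    """Scan bin boundaries using gradient/hessian histograms and return (threshold, gain)."""
    n_bins = len(edges) + 1
    g_hist = np.bincount(bin_idx, weights=grad, minlength=n_bins)
    h_hist = np.bincount(bin_idx, weights=hess, minlength=n_bins)
    g_left, h_left = np.cumsum(g_hist)[:-1], np.cumsum(h_hist)[:-1]
    g_total, h_total = g_hist.sum(), h_hist.sum()
    g_right, h_right = g_total - g_left, h_total - h_left
    gain = (g_left ** 2 / (h_left + reg_lambda)
            + g_right ** 2 / (h_right + reg_lambda)
            - g_total ** 2 / (h_total + reg_lambda))
    b = int(np.argmax(gain))
    return float(edges[b]), float(gain[b])  # rows with value < edges[b] go left


# Usage: squared loss around a constant prediction, so g = pred - y and h = 1.
x = np.linspace(0, 1, 200)
y = (x > 0.5).astype(float)
pred = np.full_like(y, y.mean())
grad, hess = pred - y, np.ones_like(y)
bin_idx, edges = bin_feature(x)
threshold, gain = best_split(bin_idx, edges, grad, hess)  # scans len(edges) candidates, not 200
```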
In the Dask version of LightGBM, each worker in the cluster holds a horizontally partitioned slice of the training data (a subset of the rows) and knows the other workers' addresses so it can open TCP connections to them. The workers are decentralized, each with its own rank, and they synchronize only when a new split is created. Each worker computes gradient histograms over its local rows, the histograms are combined across workers, and all workers apply the same split. As a result, at the end of training every worker holds an exact copy of the model, so training is also consistent across workers.
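The toy NumPy simulation below shows why combining histograms is cheap. It is an illustration under simplifying assumptions, not LightGBM's networking code. It splits rows across four hypothetical workers, lets each build a gradient histogram over its own rows, and sums the results. The sum equals the histogram of the full data, yet each worker contributes only one value per bin rather than its rows. The sketch assumes every worker already maps rows to the same shared bin boundaries.

```python
import numpy as np


def local_histogram(bin_idx, grad, n_bins):
    """Gradient histogram computed by one worker over its own rows only."""
    return np.bincount(bin_idx, weights=grad, minlength=n_bins)


rng = np.random.default_rng(0)
n_rows, n_bins, n_workers = 10_000, 32, 4
bin_idx = rng.integers(0, n_bins, size=n_rows)  # rows already mapped to shared bins
grad = rng.normal(size=n_rows)

# Horizontal partitioning: each worker gets a contiguous slice of the rows.
shards = np.array_split(np.arange(n_rows), n_workers)
local = [local_histogram(bin_idx[s], grad[s], n_bins) for s in shards]

# The "network" step: summing per-worker histograms (an all-reduce on a real cluster).
global_hist = np.sum(local, axis=0)
direct = local_histogram(bin_idx, grad, n_bins)  # what a single machine would compute
message_size = local[0].size  # values one worker sends per feature: n_bins, not its row count
```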

What about a cluster to process with?

One of the key features of Dask is the smooth transition from local computing to cluster computing, supported by several APIs. Here we'll use AWS Fargate without going into the details of cluster deployment. So let's first create a cluster definition that instantiates either a local cluster for rapid testing or a parametrized ECS cluster.

cluster setup snippet
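Here is a minimal sketch of such a cluster definition. The Fargate case assumes the dask-cloudprovider package, whose `FargateCluster` accepts `image`, `n_workers`, `worker_cpu` (CPU units, 1024 = 1 vCPU) and `worker_mem` (MiB). Check its documentation for the full parameter set and for AWS credentials. The `make_cluster` name and its defaults are hypothetical. Only the local branch is exercised below.

```python
from dask.distributed import Client, LocalCluster


def make_cluster(local=True, n_workers=4, processes=False,
                 image=None, worker_cpu=4096, worker_mem=16384):
    """Return a LocalCluster for rapid tests, or a Fargate-backed ECS cluster."""
    if local:
        # processes=False runs the workers as threads in this process: quickest to start.
        return LocalCluster(n_workers=n_workers, threads_per_worker=1, processes=processes)
    # Imported lazily so local runs need neither dask-cloudprovider nor AWS credentials.
    from dask_cloudprovider.aws import FargateCluster

    return FargateCluster(
        image=image,            # Docker image with the same Dask/LightGBM versions as the client
        n_workers=n_workers,
        worker_cpu=worker_cpu,  # CPU units: 1024 = 1 vCPU
        worker_mem=worker_mem,  # memory in MiB
    )


# Usage: a two-worker local cluster for a quick smoke test.
cluster = make_cluster(local=True, n_workers=2)
client = Client(cluster)
n_connected = len(client.scheduler_info()["workers"])
client.close()
cluster.close()
```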

For the sake of this demonstration, let's assume that our large dataset is already partitioned and persisted in a cloud object storage service. Hence we'll use the S3 file system library (s3fs) to access those data objects. To keep things as simple as possible, the dataset is already preprocessed and stored in the classic [X: y] format (features plus a target), which is compatible with most of the available supervised learning algorithms.

And what about data format?

Structured data for machine learning applications is usually found in CSV format. However, this file type has some underlying problems with big data:
  • it stores no data types,
  • it relies on delimiter characters,
  • its encoding can be limiting (depending on the language).
In Dask's implementation, CSV is also slow to parse and cannot be filtered with predicate expressions. One workaround is to define a dataset schema up front; otherwise the data types must be inferred, which takes time and is prone to errors when a column contains invalid values.
Among the alternative formats, Parquet is an excellent option. Since Dask uses PyArrow under the hood to speed up distributed reads, Parquet is the format we will use.

The following shell screenshot shows how this data is organized in the S3 file system.

cloud stored dataset

Where are we heading?

Before digging into a minimal working example, let's look at how the data pipeline uses AWS services to run the distributed training. This pipeline is triggered manually.

Let's load some TBs of data into our cluster

To start with, let's define a function that loads the whole dataset into our Dask cluster through a map/reduce strategy, where the result is an eager representation of the graph. Using the distributed client, we take a list of partitions and a predefined set of features and concatenate them into one large dataset.

parse dataset snippet

A word on hyper-parameter optimization

Before moving on to model training, we should define a proper method for hyper-parameter search. Grid search is probably the most popular among practitioners. However, with a dataset this large it will perform poorly even with distributed training, because of its brute-force strategy: its cost is exponential, O(n^k), where n is the number of candidate values per hyper-parameter and k is the number of hyper-parameters being tuned. A wiser choice is random search, which samples configurations at random for a fixed number of iterations to approach a good solution at low computational cost. Its cost is set by that iteration budget, i.e. O(1) with respect to n and k.
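The arithmetic behind that argument fits in a small Python sketch with a made-up search space. With n = 5 candidate values for each of k = 4 LightGBM parameters, grid search needs 5^4 = 625 full training runs, whereas random search costs exactly the budget you choose. Random search can also sample continuous ranges (for example a log-uniform learning rate) directly. The parameter names are real arguments of LightGBM's scikit-learn API; the value ranges and the budget of 20 are assumptions.

```python
import math
import random

# Hypothetical search space: 5 candidate values for each of k = 4 LightGBM parameters.
grid = {
    "num_leaves": [15, 31, 63, 127, 255],
    "learning_rate": [0.01, 0.03, 0.05, 0.1, 0.3],
    "min_child_samples": [5, 10, 20, 50, 100],
    "colsample_bytree": [0.5, 0.6, 0.7, 0.8, 1.0],
}
grid_fits = math.prod(len(values) for values in grid.values())  # n**k = 5**4 training runs


def sample_config(rng):
    """One random-search draw; continuous ranges need no discretization."""
    return {
        "num_leaves": rng.randint(15, 255),
        # log-uniform: every order of magnitude is equally likely
        "learning_rate": math.exp(rng.uniform(math.log(0.01), math.log(0.3))),
        "min_child_samples": rng.randint(5, 100),
        "colsample_bytree": rng.uniform(0.5, 1.0),
    }


budget = 20  # fixed number of training runs, chosen up front and independent of k
rng = random.Random(42)
random_fits = [sample_config(rng) for _ in range(budget)]
```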

Keep track of the experiment

To keep track of the experiment, let's create an Experiment class. Since the chosen tuning method is random search, this object is responsible for drawing the random samples of hyper-parameters used in model training. A tracking API can also be useful here; there are a few popular open-source options, such as MLflow.

experiment class
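Here is one possible shape for such a class, offered as a hypothetical sketch rather than the author's code. It draws `n_iter` random configurations from a discrete search space with a fixed seed, logs each run's score, and returns the best one. With MLflow, the `log` method could additionally call `mlflow.log_params` and `mlflow.log_metric`.

```python
import random
from dataclasses import dataclass, field


@dataclass
class Experiment:
    """Hypothetical tracker: samples random-search configs and records their scores."""
    search_space: dict  # parameter name -> list of candidate values
    n_iter: int = 10
    seed: int = 42
    runs: list = field(default_factory=list)

    def sample_params(self):
        """Yield n_iter random configurations; the fixed seed makes the search reproducible."""
        rng = random.Random(self.seed)
        for _ in range(self.n_iter):
            yield {name: rng.choice(values) for name, values in self.search_space.items()}

    def log(self, params, score):
        self.runs.append({"params": dict(params), "score": score})

    def best(self):
        """Highest-scoring run so far (assumes higher is better, e.g. accuracy or AUC)."""
        return max(self.runs, key=lambda run: run["score"])


# Usage with a stand-in score instead of a real training run.
space = {"num_leaves": [15, 31, 63], "learning_rate": [0.05, 0.1]}
exp = Experiment(space, n_iter=5)
for params in exp.sample_params():
    fake_score = params["num_leaves"] / 100 + params["learning_rate"]  # placeholder metric
    exp.log(params, fake_score)
best = exp.best()
```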

Split 'n' Train

Once the dataset is loaded on our cluster, it is time to apply a hold-out strategy to split it into train/test partitions. Since scikit-learn's RandomizedSearchCV isn't compatible with distributed LightGBM, we're going to implement an iterative method that randomly samples values from the search space.

train model definitions

Wrapping it up!

Now that we've built the main building blocks of the model training pipeline, let's see what the main script looks like.

The snippet below uses the functions defined above to build the main procedure, from creating a cluster hosted on the elastic Fargate service to scoring the model trained in a distributed manner.

main call
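Below is a minimal sketch of the main procedure's skeleton. It assumes the pipeline is expressed as a list of step functions, which is a hypothetical structure rather than the author's script. The key design choice is that the cluster is ephemeral: the context managers create it, hand a client to each step, and always tear it down, even if a step raises. This matters when the workers are billed Fargate tasks.

```python
from dask.distributed import Client, LocalCluster


def main(steps, n_workers=2):
    """Run pipeline steps in order on a short-lived cluster that is always torn down."""
    # Swap LocalCluster for a Fargate-backed cluster in production; the `with`
    # blocks close it on exit, whether the steps succeed or fail.
    with LocalCluster(n_workers=n_workers, threads_per_worker=1, processes=False) as cluster, \
            Client(cluster) as client:
        result = None
        for step in steps:
            result = step(client, result)  # each step gets the client and the previous output
        return result


# Hypothetical stand-ins for the "load", "train" and "score" steps.
pipeline = [
    lambda client, _: client.submit(sum, [1, 2, 3]).result(),
    lambda client, prev: prev * 2,
]
output = main(pipeline)
```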

How to check how the cluster is doing?

Diagnosing performance in a distributed cluster can be somewhat complex. However, Dask comes with a built-in dashboard that tracks the execution of each task and shows how much computational power each worker is using. The Dask docs have very solid tips on doing this.
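Besides the dashboard, the client exposes some of the same information programmatically. The small sketch below assumes a local test cluster. `client.dashboard_link` gives the dashboard URL, and `get_task_stream` records each finished task (including the worker that ran it), the same kind of data the dashboard's task-stream plot displays. The `square` function is just a placeholder workload.

```python
from collections import Counter

from dask.distributed import Client, LocalCluster, get_task_stream


def square(x):
    return x * x


with LocalCluster(n_workers=2, threads_per_worker=1, processes=False) as cluster, \
        Client(cluster) as client:
    print("Dashboard:", client.dashboard_link)  # open in a browser while jobs run
    with get_task_stream(client) as ts:
        total = sum(client.gather(client.map(square, range(10))))
    # One record per finished task; count how many tasks each worker executed.
    tasks_per_worker = Counter(record.get("worker") for record in ts.data)
```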

The following GIF shows this process running on 0.5 TB of data with 4 workers. The overall run took approximately 20 minutes, a figure that includes the communication time between the local client and the cloud resource manager.

Model training followup

Conclusion

In this post, we discussed how to effectively train a classification model with a distributed framework that scales beyond what monolithic, single-machine model training can handle.

Happy coding!
