Model Training Tutorial
Welcome to the model training tutorial! In this tutorial, we will train a neural network to classify tiles from our toy dataset and then visualize how well it performs. Our model is essentially a wrapper around PyTorch's ResNet-18 deep residual network (torchvision.models.resnet18), which the LUNA team modified to suit their work with tiling slides.
# Set up the home directory and point DATASET_URL at the mounted dataset
import os
HOME = os.environ['HOME']
%env DATASET_URL=file:///$HOME/vmount/PRO-12-123/
env: DATASET_URL=file:////home/limr/vmount/PRO-12-123/
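The %env magic only works inside IPython/Jupyter. A minimal plain-Python equivalent, assuming the same ~/vmount/PRO-12-123/ mount point used in this tutorial, sets os.environ directly. Building the URL with pathlib's as_uri() also yields the conventional file:///home/... form instead of the extra slash visible in the output above.

```python
import os
from pathlib import Path

# Plain-Python equivalent of `%env DATASET_URL=file:///$HOME/vmount/PRO-12-123/`.
# Assumes the dataset is mounted under ~/vmount/PRO-12-123, as in this tutorial.
dataset_dir = Path.home() / "vmount" / "PRO-12-123"

# as_uri() turns an absolute path into a standard file URL (file:///home/...);
# the trailing slash mirrors the value set with %env.
os.environ["DATASET_URL"] = dataset_dir.as_uri() + "/"

print("DATASET_URL =", os.environ["DATASET_URL"])
```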
Model Training
The model will classify tiles into the tissue types we annotated (tumor, stroma, and fat). Tissue classifier models like this one are trained with the train_tissue_classifier CLI tool. Its help text lists the available options:
!train_tissue_classifier --help
2023-04-04 18:31:40,338 - INFO - root - Initalized logger, log file at: luna.log
Usage: train_tissue_classifier [OPTIONS] TILE_DATASET_FPATH
Train a tissue classifier model for all tiles in a slide
Inputs:
tile_dataset_fpath: path to tile dataset parquet table
Outputs:
ray ExperimentAnalysis dataframe and metadata saved to the output
Example:
train_tissue_classifier /tables/slides/slide_table
-ne 5
-nt torchvision.models.resnet18
-nw 1
-o results/train_tile_classifier_results
Options:
-o, --output_dir TEXT Path to output directory to save results and
logs from Ray
-ls, --label_set TEXT Dictionary/json where keys coorespoond to
tissue types and values coorespond to
numerical values
-lc, --label_col TEXT Column name in the input dataframe
cooresponding to the tissue type (eg.
regional_label)
-sc, --stratify_col TEXT Column name in the input dataframe used to
stratify the training/validation datasets
(eg. id_slide_container or patient_id)
-nk, --num_splits TEXT The number of folds used for cross
validation
-ne, --num_epochs TEXT Number of epochs to train the model for. Can
be either a fixed integer or a RayTune grid
search
-bx, --batch_size TEXT Batch size used train the model. Can be
either a fixed integer or a RayTune grid
search
-lr, --learning_rate TEXT Learning rate used for the ADAM optimizer.
Can be either a float or a RayTune
distribution
-nt, --network TEXT Neural network architecture. Can be either a
nn.Module or a RayTune grid search
-ug, --use_gpu TEXT Whether or not use use GPUs for model
training
-cw, --num_cpus_per_worker TEXT
Number of CPUs transparent to each worker
-gw, --num_gpus_per_worker TEXT
Number of GPUs transparent to each worker.
Can't be more than num_gpus
-ng, --num_gpus TEXT Number of GPUs in total transparent to Ray
-nc, --num_cpus TEXT Number of CPUs in total transparent to Ray
-nw, --num_workers TEXT Total number of workers. Cooresponds to
number of models to train concurrently.
-ns, --num_samples TEXT number of trials to run
-m, --method_param_path TEXT path to a metadata json/yaml file with
method parameters to reproduce results
--help Show this message and exit.
This CLI tool has many command-line arguments. The main input is the labeled tile dataset, which is used both to train and to validate the model. The label_col argument names the column that holds each tile's tissue type, and label_set maps each tissue type to a numerical class index. For validation, the tiles are stratified by the column given in stratify_col (for example, a patient ID or a slide ID), and the split is controlled by the num_splits parameter. Training can use no GPUs, one, or many, as well as multiple CPUs, through Ray; the arguments that control resources are num_gpus, num_cpus, num_workers, num_cpus_per_worker, and num_gpus_per_worker. To experiment with different hyperparameters, you can supply a list of values for certain arguments, such as learning_rate or batch_size, and Ray will run a hyperparameter search (sweep) over them.
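To make the label_set and stratify_col ideas concrete, here is a minimal sketch. It is not the tool's implementation: the tile records are hypothetical, and the round-robin assignment of whole slides to folds is a simple grouped split chosen for illustration (library helpers such as scikit-learn's GroupKFold balance fold sizes more carefully). Note that the label set is passed on the command line as a Python-style dict literal with single quotes, which strict JSON parsing would reject; ast.literal_eval accepts it, although how the CLI itself parses the string is an assumption here.

```python
import ast

# Parse the label set string used in the example command below.
label_set = ast.literal_eval("{'tumor':0, 'stroma':1, 'fat':2}")

# Hypothetical tiles as (slide_id, regional_label) pairs, standing in for the
# stratify_col and label_col columns of the tile dataset.
tiles = [
    ("slide_a", "tumor"), ("slide_a", "stroma"),
    ("slide_b", "fat"), ("slide_b", "tumor"),
    ("slide_c", "stroma"), ("slide_d", "fat"),
]

# Map tissue-type names to the integer class indices the network trains on.
targets = [label_set[label] for _, label in tiles]


def assign_folds(group_ids, num_splits):
    """Assign every group (e.g. a slide) to exactly one fold, round-robin, so
    tiles from the same slide never land in both training and validation."""
    unique_groups = sorted(set(group_ids))
    fold_of_group = {g: i % num_splits for i, g in enumerate(unique_groups)}
    return [fold_of_group[g] for g in group_ids]


num_splits = 2
folds = assign_folds([slide for slide, _ in tiles], num_splits)

# For fold k, tiles assigned to fold k validate and all other tiles train.
for k in range(num_splits):
    val_slides = sorted({s for (s, _), f in zip(tiles, folds) if f == k})
    train_slides = sorted({s for (s, _), f in zip(tiles, folds) if f != k})
    print(f"fold {k}: train={train_slides} val={val_slides}")
```

Grouping by slide (or patient) rather than splitting individual tiles matters because neighboring tiles from one slide are highly correlated; letting them appear on both sides of the split would inflate validation scores.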
In the following example, we train a ResNet-18 model (although any model available from PyTorch can be used) for two epochs, using CPUs only.
%%bash
train_tissue_classifier ~/vmount/PRO-12-123/datasets/PRO_TILES_LABELED/segments \
--label_set "{'tumor':0, 'stroma':1, 'fat':2}" \
--label_col regional_label --stratify_col slide_id \
--num_epochs 2 --network 'torchvision.models.resnet18' \
--num_splits 2 \
--batch_size 4 \
-lr 1e-4 \
-cw 4 -gw 0 -nw 1 -ng 0 -nc 5 -ns 1 \
--output_dir ../PRO-12-123/tissue_classifier_results
2023-04-04 18:31:44,554 - INFO - root - Initalized logger, log file at: luna.log
2023-04-04 18:31:44,557 - INFO - luna.common.utils - Started CLI Runner wtih <function train_model at 0x7f1f35cb2670>
2023-04-04 18:31:44,559 - INFO - luna.common.utils - Validating params...
2023-04-04 18:31:44,562 - INFO - luna.common.utils - -> Set tile_dataset_fpath (<class 'str'>) = /home/limr/vmount/PRO-12-123/datasets/PRO_TILES_LABELED/segments
2023-04-04 18:31:44,566 - INFO - luna.common.utils - -> Set output_dir (<class 'str'>) = ../PRO-12-123/tissue_classifier_results
2023-04-04 18:31:44,569 - INFO - luna.common.utils - -> Set label_set (<class 'dict'>) = {'tumor': 0, 'stroma': 1, 'fat': 2}
2023-04-04 18:31:44,571 - INFO - luna.common.utils - -> Set label_col (<class 'str'>) = regional_label
2023-04-04 18:31:44,573 - INFO - luna.common.utils - -> Set stratify_col (<class 'str'>) = slide_id
2023-04-04 18:31:44,575 - INFO - luna.common.utils - -> Set num_splits (<class 'int'>) = 2
2023-04-04 18:31:44,577 - INFO - luna.common.utils - -> Set num_epochs (typing.List[int]) = [2]
2023-04-04 18:31:44,580 - INFO - luna.common.utils - -> Set batch_size (typing.List[int]) = [4]
2023-04-04 18:31:44,582 - INFO - luna.common.utils - -> Set learning_rate (typing.List[float]) = [0.0001]
2023-04-04 18:31:44,585 - INFO - luna.common.utils - -> Set network (<class 'str'>) = torchvision.models.resnet18
2023-04-04 18:31:44,589 - INFO - luna.common.utils - -> Set num_cpus_per_worker (<class 'int'>) = 4
2023-04-04 18:31:44,592 - INFO - luna.common.utils - -> Set num_gpus_per_worker (<class 'int'>) = 0
2023-04-04 18:31:44,594 - INFO - luna.common.utils - -> Set num_gpus (<class 'int'>) = 0
2023-04-04 18:31:44,597 - INFO - luna.common.utils - -> Set num_cpus (<class 'int'>) = 5
2023-04-04 18:31:44,599 - INFO - luna.common.utils - -> Set num_workers (<class 'int'>) = 1
2023-04-04 18:31:44,603 - INFO - luna.common.utils - -> Set num_samples (<class 'int'>) = 1
2023-04-04 18:31:44,607 - INFO - luna.common.utils - Expanding inputs...
2023-04-04 18:31:44,610 - INFO - luna.common.utils - Full segment key set: {}
2023-04-04 18:31:44,612 - INFO - luna.common.utils - ------------------------------------------------------------
2023-04-04 18:31:44,612 - INFO - luna.common.utils - Starting transform::train_model
2023-04-04 18:31:44,612 - INFO - luna.common.utils - ------------------------------------------------------------
2023-04-04 18:31:44,615 - INFO - train_tissue_classifier - Training a tissue classifier with: network=torchvision.models.resnet18, batch_size=[4], learning_rate=[0.0001]
2023-04-04 18:31:44,618 - INFO - train_tissue_classifier - Initilizing Ray Cluster, with: num_gpus=0, num_workers=1
2023-04-04 18:31:46,516 WARNING services.py:1780 -- WARNING: The object store is using /tmp instead of /dev/shm because /dev/shm has only 67108864 bytes available. This will harm performance! You may be able to free up space by deleting files in /dev/shm. If you are inside a Docker container, you can increase /dev/shm size by passing '--shm-size=6.50gb' to 'docker run' (or add it to the run_options list in a Ray cluster config). Make sure to set this to more than 30% of available RAM.
2023-04-04 18:31:46,668 INFO worker.py:1553 -- Started a local Ray instance.
2023-04-04 18:31:47,634 - INFO - train_tissue_classifier - View Ray Dashboard to see worker logs:
2023-04-04 18:31:47,636 - INFO - train_tissue_classifier - training model
2023-04-04 18:31:47,639 - INFO - train_tissue_classifier - Instantiating Ray Trainer with: num_cpus_per_worker=4, num_gpus_per_worker=0
2023-04-04 18:31:47,642 - INFO - train_tissue_classifier - Trainer logs will be logged in: /home/limr/vmount/PRO-12-123/tissue_classifier_results
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - == Status ==
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Current time: 2023-04-04 18:31:52 (running for 00:00:04.55)
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Memory usage on this node: 5.7/20.3 GiB
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Using FIFO scheduling algorithm.
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Resources requested: 5.0/5 CPUs, 0/0 GPUs, 0.0/11.81 GiB heap, 0.0/5.91 GiB objects
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Result logdir: /home/limr/vmount/PRO-12-123/tissue_classifier_results/TorchTrainer_2023-04-04_18-31-47
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Number of trials: 1/1 (1 RUNNING)
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - +--------------------------+----------+----------------+
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - | Trial name | status | loc |
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - |--------------------------+----------+----------------|
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - | TorchTrainer_f5502_00000 | RUNNING | 172.21.0.5:458 |
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - +--------------------------+----------+----------------+
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier -
(TrainTrainable pid=458) 2023-04-04 18:31:52,252 - INFO - root - Initalized logger, log file at: luna.log
(RayTrainWorker pid=517) 2023-04-04 18:31:54,760 INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=517) 2023-04-04 18:31:56,694 - INFO - root - Initalized logger, log file at: luna.log
(RayTrainWorker pid=517) 2023-04-04 18:31:56,705 - INFO - train_tissue_classifier - Configuring model training driver function...
(RayTrainWorker pid=517) 2023-04-04 18:31:57,147 INFO train_loop_utils.py:255 -- Moving model to device: cpu
(RayTrainWorker pid=517) 2023-04-04 18:31:57,159 - INFO - train_tissue_classifier - Starting training procedure
Result for TorchTrainer_f5502_00000:
_time_this_iter_s: 75.94302082061768
_timestamp: 1680633192
_training_iteration: 1
date: 2023-04-04_18-33-12
done: false
experiment_id: cba2c49339824b0ca538a08eee4ab134
hostname: 74d50108ecab
iterations_since_restore: 1
node_ip: 172.21.0.5
pid: 458
time_since_restore: 80.38494443893433
time_this_iter_s: 80.38494443893433
time_total_s: 80.38494443893433
timestamp: 1680633192
timesteps_since_restore: 0
train_Accuracy: 0.6581818461418152
train_F1Score: 0.6581818461418152
train_Precision: 0.6581818461418152
train_Recall: 0.6581818461418152
train_loss: 0.8410024582475856
training_iteration: 1
trial_id: f5502_00000
val_Accuracy: 0.6540229916572571
val_ConfusionMatrix:
- - 369
  - 1
  - 161
- - 28
  - 0
  - 103
- - 8
  - 0
  - 200
val_F1Score: 0.6540229916572571
val_Precision: 0.6540229916572571
val_Recall: 0.6540229916572571
val_loss: 0.713255486715961
warmup_time: 1.6474692821502686
Result for TorchTrainer_f5502_00000:
_time_this_iter_s: 73.28087425231934
_timestamp: 1680633265
_training_iteration: 2
date: 2023-04-04_18-34-25
done: false
experiment_id: cba2c49339824b0ca538a08eee4ab134
hostname: 74d50108ecab
iterations_since_restore: 2
node_ip: 172.21.0.5
pid: 458
time_since_restore: 153.66580271720886
time_this_iter_s: 73.28085827827454
time_total_s: 153.66580271720886
timestamp: 1680633265
timesteps_since_restore: 0
train_Accuracy: 0.7418181896209717
train_F1Score: 0.7418181896209717
train_Precision: 0.7418181896209717
train_Recall: 0.7418181896209717
train_loss: 0.47712825411471765
training_iteration: 2
trial_id: f5502_00000
val_Accuracy: 0.682758629322052
val_ConfusionMatrix:
- - 900
  - 1
  - 161
- - 138
  - 16
  - 108
- - 128
  - 16
  - 272
val_F1Score: 0.682758629322052
val_Precision: 0.682758629322052
val_Recall: 0.682758629322052
val_loss: 0.8418892386429216
warmup_time: 1.6474692821502686
(RayTrainWorker pid=517) 2023-04-04 18:34:26,387 - INFO - train_tissue_classifier - Completed model training
Trial TorchTrainer_f5502_00000 completed.
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - == Status ==
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Current time: 2023-04-04 18:34:29 (running for 00:02:41.80)
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Memory usage on this node: 6.1/20.3 GiB
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Using FIFO scheduling algorithm.
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Resources requested: 0/5 CPUs, 0/0 GPUs, 0.0/11.81 GiB heap, 0.0/5.91 GiB objects
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Result logdir: /home/limr/vmount/PRO-12-123/tissue_classifier_results/TorchTrainer_2023-04-04_18-31-47
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Number of trials: 1/1 (1 TERMINATED)
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - +--------------------------+------------+----------------+--------+------------------+--------------+------------+--------------+
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - | Trial name | status | loc | iter | total time (s) | train_loss | val_loss | _timestamp |
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - |--------------------------+------------+----------------+--------+------------------+--------------+------------+--------------|
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - | TorchTrainer_f5502_00000 | TERMINATED | 172.21.0.5:458 | 2 | 153.666 | 0.477128 | 0.841889 | 1680633265 |
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - +--------------------------+------------+----------------+--------+------------------+--------------+------------+--------------+
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier -
2023-04-04 18:34:29,517 INFO tune.py:798 -- Total run time: 161.84 seconds (161.78 seconds for the tuning loop).
2023-04-04 18:34:29,541 - INFO - train_tissue_classifier - Finished training
2023-04-04 18:34:29,551 - INFO - train_tissue_classifier - train_Accuracy ... logdir
2023-04-04 18:34:29,551 - INFO - train_tissue_classifier - 0 0.741818 ... /home/limr/vmount/PRO-12-123/tissue_classifier...
2023-04-04 18:34:29,551 - INFO - train_tissue_classifier -
2023-04-04 18:34:29,551 - INFO - train_tissue_classifier - [1 rows x 39 columns]
2023-04-04 18:34:29,609 - INFO - train_tissue_classifier - Output: /home/limr/vmount/PRO-12-123/tissue_classifier_results
2023-04-04 18:34:31,934 - INFO - luna.common.utils - Code block 'transform::train_model' took: 167.3163875049795s
2023-04-04 18:34:31,937 - INFO - luna.common.utils - ------------------------------------------------------------
2023-04-04 18:34:31,937 - INFO - luna.common.utils - Done with transform, running post-transform functions...
2023-04-04 18:34:31,937 - INFO - luna.common.utils - ------------------------------------------------------------
2023-04-04 18:34:31,966 - INFO - luna.common.utils - Done.
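While the trial was running, the status table reported "Resources requested: 5.0/5 CPUs", even though the single worker (-nw 1) asked for 4 CPUs (-cw 4). That is consistent with Ray Train reserving one extra CPU for the trainer process by default (its trainer_resources setting); the help text above does not list an option for it. The hypothetical helper below does the same arithmetic before launching, so you can check that num_cpus and num_gpus cover what the workers request. trainer_cpus=1 is an assumption about that default.

```python
def check_ray_budget(num_workers, cpus_per_worker, gpus_per_worker,
                     num_cpus, num_gpus, trainer_cpus=1):
    """Return (cpus_needed, gpus_needed) for one trial and raise ValueError if
    the request exceeds the CPUs/GPUs given to Ray (such a trial could not be
    scheduled).

    trainer_cpus=1 assumes Ray Train's default of one CPU for the trainer.
    """
    cpus_needed = num_workers * cpus_per_worker + trainer_cpus
    gpus_needed = num_workers * gpus_per_worker
    if cpus_needed > num_cpus:
        raise ValueError(f"need {cpus_needed} CPUs, but num_cpus={num_cpus}")
    if gpus_needed > num_gpus:
        raise ValueError(f"need {gpus_needed} GPUs, but num_gpus={num_gpus}")
    return cpus_needed, gpus_needed


# Values from the command above: -nw 1 -cw 4 -gw 0 -nc 5 -ng 0
print(check_ray_budget(num_workers=1, cpus_per_worker=4, gpus_per_worker=0,
                       num_cpus=5, num_gpus=0))  # (5, 0)
```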
Results
Now that we have a trained model, we can inspect the output directory.
%%bash
ls /home/limr/vmount/PRO-12-123/tissue_classifier_results/TorchTrainer_2023-04-04_18-31-47
TorchTrainer_f5502_00000_0_2023-04-04_18-31-47
basic-variant-state-2023-04-04_18-31-47.json
experiment_state-2023-04-04_18-31-47.json
trainable.pkl
tuner.pkl
!ls -lat ../PRO-12-123/tissue_classifier_results
total 87496
-rw-r--r-- 1 limr limr 12354 Apr 4 18:34 metadata.yml
drwxr-xr-x 7 limr limr 224 Apr 4 18:34 TorchTrainer_2023-04-04_18-31-47
-rw-r--r-- 1 limr limr 44788557 Apr 4 18:34 checkpoint_1.pt
drwxr-xr-x 17 limr limr 544 Apr 4 18:34 .
-rw-r--r-- 1 limr limr 44788557 Apr 4 18:33 checkpoint_0.pt
drwxr-xr-x 7 limr limr 224 Apr 4 17:52 TorchTrainer_2023-04-04_17-51-33
drwxr-xr-x 7 limr limr 224 Apr 4 17:46 TorchTrainer_2023-04-04_17-45-03
drwxr-xr-x 7 limr limr 224 Apr 4 17:11 TorchTrainer_2023-04-04_17-10-05
drwxr-xr-x 7 limr limr 224 Apr 4 16:08 TorchTrainer_2023-04-04_16-06-53
drwxr-xr-x 7 limr limr 224 Apr 3 21:43 TorchTrainer_2023-04-03_21-43-48
drwxr-xr-x 7 limr limr 224 Apr 3 21:37 TorchTrainer_2023-04-03_21-37-39
drwxr-xr-x 7 limr limr 224 Apr 3 21:21 TorchTrainer_2023-04-03_21-21-44
drwxr-xr-x 7 limr limr 224 Apr 3 21:09 TorchTrainer_2023-04-03_21-09-44
drwxr-xr-x 6 limr limr 192 Apr 3 21:09 TorchTrainer_2023-04-03_20-57-21
drwxr-xr-x 6 limr limr 192 Apr 3 20:51 TorchTrainer_2023-04-03_20-35-01
drwxr-xr-x 4 limr limr 128 Apr 3 18:02 TorchTrainer_2023-04-03_18-02-38
drwxr-xr-x 8 root root 256 Mar 31 21:43 ..
Each time the model is trained, Ray creates a new output directory named after the trainer and the start time (for example, TorchTrainer_2023-04-04_18-31-47) to manage that run. To inspect a run, load its output directory with Ray's ExperimentAnalysis class. Its results_df attribute is a dataframe that stores performance metrics, the hyperparameters used to configure the model, and other output metadata.
from ray.tune import ExperimentAnalysis
RAY_OUTPUT = "TorchTrainer_2023-04-04_18-31-47"  # change this to the output folder you want to inspect
output_dir = "../PRO-12-123/tissue_classifier_results"
ray_output_dir = os.path.join(output_dir, RAY_OUTPUT)
analysis = ExperimentAnalysis(ray_output_dir)
display(analysis.results_df)
| trial_id | train_Accuracy | train_Precision | train_Recall | train_F1Score | train_loss | val_Accuracy | val_Precision | val_Recall | val_F1Score | val_ConfusionMatrix | ... | iterations_since_restore | warmup_time | experiment_tag | config/scaling_config/trainer_resources | config/scaling_config/num_workers | config/scaling_config/use_gpu | config/scaling_config/placement_strategy | config/scaling_config/_max_cpu_fraction_per_node | config/scaling_config/resources_per_worker/CPU | config/scaling_config/resources_per_worker/GPU |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| f5502_00000 | 0.7418182 | 0.7418182 | 0.7418182 | 0.7418182 | 0.477128 | 0.6827586 | 0.6827586 | 0.6827586 | 0.6827586 | [[900, 1, 161], [138, 16, 108], [128, 16, 272]] | ... | 2 | 1.647469 | 0 | None | 1 | False | PACK | None | 4 | 0 |
1 rows × 38 columns
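results_df has one row per trial. This run used a single value for every tunable argument and -ns 1, so it produced exactly one trial; passing lists of values multiplies the number of trials. The plain-Python sketch below counts trials the way Ray Tune documents grid_search expansion (the Cartesian product of all grid values, with the whole grid repeated num_samples times). It does not call Ray, and how train_tissue_classifier converts CLI lists into a Ray search space is an assumption here. Values given as a RayTune distribution (allowed for learning_rate) are sampled once per trial rather than enumerated, so they do not multiply the count.

```python
from itertools import product


def expand_grid(param_grid, num_samples=1):
    """List the trial configurations produced by grid-searching param_grid:
    every combination of values becomes one trial, and the whole grid is
    repeated num_samples times."""
    names = list(param_grid)
    grid = [dict(zip(names, values))
            for values in product(*(param_grid[name] for name in names))]
    return [dict(config) for _ in range(num_samples) for config in grid]


# The run above: one value per argument and -ns 1, hence a single trial.
this_run = {"num_epochs": [2], "batch_size": [4], "learning_rate": [1e-4]}
print(len(expand_grid(this_run)))  # 1

# A hypothetical sweep over two batch sizes and two learning rates.
sweep = {"num_epochs": [2], "batch_size": [4, 8], "learning_rate": [1e-4, 1e-3]}
for config in expand_grid(sweep):
    print(config)
print(len(expand_grid(sweep, num_samples=3)))  # 12
```

Each extra grid value multiplies training time, so it is worth checking this count against the resources you give Ray before launching a sweep.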
We can use the validation confusion matrix stored in this dataframe to plot a normalized confusion matrix.
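import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

label_dict = {'tumor': 0, 'stroma': 1, 'fat': 2}
# order class names by their numeric index so they line up with the matrix rows/columns
labels = sorted(label_dict, key=label_dict.get)

# results_df may hold the matrix as a nested list, so convert it to a float array
cm = np.array(analysis.results_df['val_ConfusionMatrix'].iloc[0], dtype=float)
# normalize each row (assumed to be the true label) to the fraction of that class's tiles
cm = cm / cm.sum(axis=1, keepdims=True)

df_cm = pd.DataFrame(cm, index=labels, columns=labels)
display(df_cm)

ax = sns.heatmap(df_cm, annot=True, fmt='.2f', cmap='Blues')
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
plt.show()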
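In the training log, val_Accuracy, val_Precision, val_Recall and val_F1Score are identical at every epoch. That is what micro-averaging produces for single-label multiclass classification, where all three metrics collapse to accuracy, so those numbers hide how each tissue type fares. A minimal NumPy sketch, assuming rows are true labels and columns are predictions (the torchmetrics convention, and the one the row normalization above relies on), derives per-class metrics from the epoch-2 matrix shown in results_df:

```python
import numpy as np

# Epoch-2 validation confusion matrix, copied from the results_df output above.
# Assumed layout: rows = true labels, columns = predictions, in label_set order.
labels = ["tumor", "stroma", "fat"]
cm = np.array([[900, 1, 161],
               [138, 16, 108],
               [128, 16, 272]])

tp = np.diag(cm)
recall = tp / cm.sum(axis=1)     # correct / all tiles of that true class
precision = tp / cm.sum(axis=0)  # correct / all tiles predicted as that class
f1 = 2 * precision * recall / (precision + recall)

# Micro-averaging pools every tile, so precision, recall and F1 equal accuracy.
micro = tp.sum() / cm.sum()

for name, p, r, f in zip(labels, precision, recall, f1):
    print(f"{name:>6}: precision={p:.3f} recall={r:.3f} f1={f:.3f}")
print(f"micro-average (= accuracy): {micro:.7f}")  # 0.6827586
print(f"macro-F1: {f1.mean():.3f}")
```

On these numbers stroma is the weak class: only 16 of its 262 validation tiles are predicted as stroma. Note also that the row totals of the epoch-2 matrix (1062, 262, 416) are exactly twice those of epoch 1 (531, 131, 208), which suggests the reported validation metrics may accumulate across epochs rather than being reset; interpret the epoch-2 values with that in mind.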
The top-level output directory (tissue_classifier_results) also contains the model checkpoints, checkpoint_*.pt, that we'll need for inference; judging by their timestamps, one is written per epoch. With our trained model and its checkpoints, we can move on to the next notebook!
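If you need to select a checkpoint programmatically, the sketch below picks the highest-numbered checkpoint_<N>.pt by comparing N as an integer. It assumes N increases with the epoch, which the file timestamps suggest but this tutorial does not state. The helper names are hypothetical, and the sketch does not load the file, since the checkpoint's contents are not described here. Keep in mind that the last checkpoint is not necessarily the best one: in this run val_loss rose from about 0.713 after epoch 1 to about 0.842 after epoch 2.

```python
import re
from pathlib import Path

_CHECKPOINT_RE = re.compile(r"checkpoint_(\d+)\.pt")


def pick_latest(names):
    """Return the checkpoint_<N>.pt name with the largest integer N.

    Comparing N as an int avoids the string-sorting trap where
    'checkpoint_10.pt' sorts before 'checkpoint_2.pt'.
    """
    numbered = [(int(m.group(1)), name) for name in names
                if (m := _CHECKPOINT_RE.fullmatch(name))]
    if not numbered:
        raise FileNotFoundError("no checkpoint_*.pt files found")
    return max(numbered, key=lambda pair: pair[0])[1]


def latest_checkpoint(output_dir):
    """Path to the highest-numbered checkpoint file in output_dir."""
    output_dir = Path(output_dir)
    names = (p.name for p in output_dir.glob("checkpoint_*.pt"))
    return output_dir / pick_latest(names)


# Demo on file names from the ls listing above (no filesystem access needed).
print(pick_latest(["checkpoint_0.pt", "checkpoint_1.pt", "metadata.yml"]))
```

In the notebook you would call latest_checkpoint("../PRO-12-123/tissue_classifier_results").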