Model Training Tutorial
Welcome to the model training tutorial! In this tutorial, we will train a neural network to classify tiles from our toy dataset and visualize how well it performs. Our model is essentially a wrapper around torchvision's ResNet-18 deep residual network, which the LUNA team adapted for classifying slide tiles.
import os

# Store the home directory in a Python variable: IPython expands $HOME
# in the %env magic below using this variable
HOME = os.environ['HOME']
%env DATASET_URL=file:///$HOME/vmount/PRO-12-123/
env: DATASET_URL=file:////home/limr/vmount/PRO-12-123/
Model Training
The model will classify tiles into the tissue types we annotated (tumor, stroma and fat). Tissue classifier models like this one are trained with the train_tissue_classifier CLI tool, whose options are listed below:
!train_tissue_classifier --help
2023-04-04 18:31:40,338 - INFO - root - Initalized logger, log file at: luna.log
Usage: train_tissue_classifier [OPTIONS] TILE_DATASET_FPATH

  Train a tissue classifier model for all tiles in a slide

  Inputs:
      tile_dataset_fpath: path to tile dataset parquet table
  Outputs:
      ray ExperimentAnalysis dataframe and metadata saved to the output
  Example:
      train_tissue_classifier /tables/slides/slide_table
          -ne 5
          -nt torchvision.models.resnet18
          -nw 1
          -o results/train_tile_classifier_results

Options:
  -o, --output_dir TEXT           Path to output directory to save results and logs from Ray
  -ls, --label_set TEXT           Dictionary/json where keys coorespoond to tissue types and values coorespond to numerical values
  -lc, --label_col TEXT           Column name in the input dataframe cooresponding to the tissue type (eg. regional_label)
  -sc, --stratify_col TEXT        Column name in the input dataframe used to stratify the training/validation datasets (eg. id_slide_container or patient_id)
  -nk, --num_splits TEXT          The number of folds used for cross validation
  -ne, --num_epochs TEXT          Number of epochs to train the model for. Can be either a fixed integer or a RayTune grid search
  -bx, --batch_size TEXT          Batch size used train the model. Can be either a fixed integer or a RayTune grid search
  -lr, --learning_rate TEXT       Learning rate used for the ADAM optimizer. Can be either a float or a RayTune distribution
  -nt, --network TEXT             Neural network architecture. Can be either a nn.Module or a RayTune grid search
  -ug, --use_gpu TEXT             Whether or not use use GPUs for model training
  -cw, --num_cpus_per_worker TEXT Number of CPUs transparent to each worker
  -gw, --num_gpus_per_worker TEXT Number of GPUs transparent to each worker. Can't be more than num_gpus
  -ng, --num_gpus TEXT            Number of GPUs in total transparent to Ray
  -nc, --num_cpus TEXT            Number of CPUs in total transparent to Ray
  -nw, --num_workers TEXT         Total number of workers. Cooresponds to number of models to train concurrently.
  -ns, --num_samples TEXT         number of trials to run
  -m, --method_param_path TEXT    path to a metadata json/yaml file with method parameters to reproduce results
  --help                          Show this message and exit.
This CLI tool has many command-line arguments. The main input is the labeled tile dataset (a parquet table), which provides the data used to both train and validate the model. For validation, the tiles are stratified by the column passed as stratify_col (for example slide_id or patient_id), and the number of cross-validation folds is set by num_splits. The label_set parameter maps each tissue type to a numerical class label.

Training can use no GPUs, one GPU, or many GPUs/CPUs through Ray. The resources are controlled by num_gpus, num_cpus, num_workers, num_cpus_per_worker and num_gpus_per_worker. If you want to experiment with different hyperparameters, some arguments accept a Ray Tune search space instead of a single value: according to the help text above, num_epochs, batch_size and network accept a grid search, and learning_rate accepts a distribution. Ray then runs a hyperparameter search (sweep) over those values.
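The label_set option arrives as a string, and the training log further down shows it parsed into a Python dict. The sketch below illustrates, using only the standard library, how such a string can be parsed and then used to turn each tile's regional_label into an integer class. It assumes a Python-literal parser (ast.literal_eval); LUNA may parse the option differently, and the tile labels are made up for illustration.

```python
import ast

# The label set exactly as passed on the command line
label_set_arg = "{'tumor':0, 'stroma':1, 'fat':2}"

# literal_eval parses Python literals (dicts, lists, numbers, strings)
# without executing arbitrary code
label_set = ast.literal_eval(label_set_arg)

# Hypothetical regional_label values for a handful of tiles
regional_labels = ["tumor", "fat", "stroma", "tumor"]

# Map each tissue-type name to its integer class index
targets = [label_set[name] for name in regional_labels]
print(targets)  # [0, 2, 1, 0]
```

A tile whose label is missing from label_set would raise a KeyError here, so the keys must cover every value in the label column.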
In the following example, we train a ResNet-18 model (torchvision.models.resnet18) for two epochs on CPUs only. Other architectures can be selected with the --network option.
%%bash
train_tissue_classifier ~/vmount/PRO-12-123/datasets/PRO_TILES_LABELED/segments \
    --label_set "{'tumor':0, 'stroma':1, 'fat':2}" \
    --label_col regional_label \
    --stratify_col slide_id \
    --num_splits 2 \
    --num_epochs 2 \
    --network 'torchvision.models.resnet18' \
    --batch_size 4 \
    --learning_rate 1e-4 \
    --num_cpus_per_worker 4 \
    --num_gpus_per_worker 0 \
    --num_workers 1 \
    --num_gpus 0 \
    --num_cpus 5 \
    --num_samples 1 \
    --output_dir ../PRO-12-123/tissue_classifier_results
2023-04-04 18:31:44,554 - INFO - root - Initalized logger, log file at: luna.log
2023-04-04 18:31:44,557 - INFO - luna.common.utils - Started CLI Runner wtih <function train_model at 0x7f1f35cb2670>
2023-04-04 18:31:44,559 - INFO - luna.common.utils - Validating params...
2023-04-04 18:31:44,562 - INFO - luna.common.utils - -> Set tile_dataset_fpath (<class 'str'>) = /home/limr/vmount/PRO-12-123/datasets/PRO_TILES_LABELED/segments
2023-04-04 18:31:44,566 - INFO - luna.common.utils - -> Set output_dir (<class 'str'>) = ../PRO-12-123/tissue_classifier_results
2023-04-04 18:31:44,569 - INFO - luna.common.utils - -> Set label_set (<class 'dict'>) = {'tumor': 0, 'stroma': 1, 'fat': 2}
2023-04-04 18:31:44,571 - INFO - luna.common.utils - -> Set label_col (<class 'str'>) = regional_label
2023-04-04 18:31:44,573 - INFO - luna.common.utils - -> Set stratify_col (<class 'str'>) = slide_id
2023-04-04 18:31:44,575 - INFO - luna.common.utils - -> Set num_splits (<class 'int'>) = 2
2023-04-04 18:31:44,577 - INFO - luna.common.utils - -> Set num_epochs (typing.List[int]) = [2]
2023-04-04 18:31:44,580 - INFO - luna.common.utils - -> Set batch_size (typing.List[int]) = [4]
2023-04-04 18:31:44,582 - INFO - luna.common.utils - -> Set learning_rate (typing.List[float]) = [0.0001]
2023-04-04 18:31:44,585 - INFO - luna.common.utils - -> Set network (<class 'str'>) = torchvision.models.resnet18
2023-04-04 18:31:44,589 - INFO - luna.common.utils - -> Set num_cpus_per_worker (<class 'int'>) = 4
2023-04-04 18:31:44,592 - INFO - luna.common.utils - -> Set num_gpus_per_worker (<class 'int'>) = 0
2023-04-04 18:31:44,594 - INFO - luna.common.utils - -> Set num_gpus (<class 'int'>) = 0
2023-04-04 18:31:44,597 - INFO - luna.common.utils - -> Set num_cpus (<class 'int'>) = 5
2023-04-04 18:31:44,599 - INFO - luna.common.utils - -> Set num_workers (<class 'int'>) = 1
2023-04-04 18:31:44,603 - INFO - luna.common.utils - -> Set num_samples (<class 'int'>) = 1
2023-04-04 18:31:44,607 - INFO - luna.common.utils - Expanding inputs...
2023-04-04 18:31:44,610 - INFO - luna.common.utils - Full segment key set: {}
2023-04-04 18:31:44,612 - INFO - luna.common.utils - ------------------------------------------------------------
2023-04-04 18:31:44,612 - INFO - luna.common.utils - Starting transform::train_model
2023-04-04 18:31:44,612 - INFO - luna.common.utils - ------------------------------------------------------------
2023-04-04 18:31:44,615 - INFO - train_tissue_classifier - Training a tissue classifier with: network=torchvision.models.resnet18, batch_size=[4], learning_rate=[0.0001]
2023-04-04 18:31:44,618 - INFO - train_tissue_classifier - Initilizing Ray Cluster, with: num_gpus=0, num_workers=1
2023-04-04 18:31:46,516 WARNING services.py:1780 -- WARNING: The object store is using /tmp instead of /dev/shm because /dev/shm has only 67108864 bytes available. This will harm performance! You may be able to free up space by deleting files in /dev/shm. If you are inside a Docker container, you can increase /dev/shm size by passing '--shm-size=6.50gb' to 'docker run' (or add it to the run_options list in a Ray cluster config). Make sure to set this to more than 30% of available RAM.
2023-04-04 18:31:46,668 INFO worker.py:1553 -- Started a local Ray instance.
2023-04-04 18:31:47,634 - INFO - train_tissue_classifier - View Ray Dashboard to see worker logs:
2023-04-04 18:31:47,636 - INFO - train_tissue_classifier - training model
2023-04-04 18:31:47,639 - INFO - train_tissue_classifier - Instantiating Ray Trainer with: num_cpus_per_worker=4, num_gpus_per_worker=0
2023-04-04 18:31:47,642 - INFO - train_tissue_classifier - Trainer logs will be logged in: /home/limr/vmount/PRO-12-123/tissue_classifier_results
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - == Status ==
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Current time: 2023-04-04 18:31:52 (running for 00:00:04.55)
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Memory usage on this node: 5.7/20.3 GiB
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Using FIFO scheduling algorithm.
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Resources requested: 5.0/5 CPUs, 0/0 GPUs, 0.0/11.81 GiB heap, 0.0/5.91 GiB objects
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Result logdir: /home/limr/vmount/PRO-12-123/tissue_classifier_results/TorchTrainer_2023-04-04_18-31-47
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - Number of trials: 1/1 (1 RUNNING)
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - +--------------------------+----------+----------------+
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - | Trial name               | status   | loc            |
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - |--------------------------+----------+----------------|
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - | TorchTrainer_f5502_00000 | RUNNING  | 172.21.0.5:458 |
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier - +--------------------------+----------+----------------+
2023-04-04 18:31:52,265 - INFO - train_tissue_classifier -
(TrainTrainable pid=458) 2023-04-04 18:31:52,252 - INFO - root - Initalized logger, log file at: luna.log
(RayTrainWorker pid=517) 2023-04-04 18:31:54,760 INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=517) 2023-04-04 18:31:56,694 - INFO - root - Initalized logger, log file at: luna.log
(RayTrainWorker pid=517) 2023-04-04 18:31:56,705 - INFO - train_tissue_classifier - Configuring model training driver function...
(RayTrainWorker pid=517) 2023-04-04 18:31:57,147 INFO train_loop_utils.py:255 -- Moving model to device: cpu
(RayTrainWorker pid=517) 2023-04-04 18:31:57,159 - INFO - train_tissue_classifier - Starting training procedure
Result for TorchTrainer_f5502_00000:
  _time_this_iter_s: 75.94302082061768
  _timestamp: 1680633192
  _training_iteration: 1
  date: 2023-04-04_18-33-12
  done: false
  experiment_id: cba2c49339824b0ca538a08eee4ab134
  hostname: 74d50108ecab
  iterations_since_restore: 1
  node_ip: 172.21.0.5
  pid: 458
  time_since_restore: 80.38494443893433
  time_this_iter_s: 80.38494443893433
  time_total_s: 80.38494443893433
  timestamp: 1680633192
  timesteps_since_restore: 0
  train_Accuracy: 0.6581818461418152
  train_F1Score: 0.6581818461418152
  train_Precision: 0.6581818461418152
  train_Recall: 0.6581818461418152
  train_loss: 0.8410024582475856
  training_iteration: 1
  trial_id: f5502_00000
  val_Accuracy: 0.6540229916572571
  val_ConfusionMatrix:
  - - 369
    - 1
    - 161
  - - 28
    - 0
    - 103
  - - 8
    - 0
    - 200
  val_F1Score: 0.6540229916572571
  val_Precision: 0.6540229916572571
  val_Recall: 0.6540229916572571
  val_loss: 0.713255486715961
  warmup_time: 1.6474692821502686

Result for TorchTrainer_f5502_00000:
  _time_this_iter_s: 73.28087425231934
  _timestamp: 1680633265
  _training_iteration: 2
  date: 2023-04-04_18-34-25
  done: false
  experiment_id: cba2c49339824b0ca538a08eee4ab134
  hostname: 74d50108ecab
  iterations_since_restore: 2
  node_ip: 172.21.0.5
  pid: 458
  time_since_restore: 153.66580271720886
  time_this_iter_s: 73.28085827827454
  time_total_s: 153.66580271720886
  timestamp: 1680633265
  timesteps_since_restore: 0
  train_Accuracy: 0.7418181896209717
  train_F1Score: 0.7418181896209717
  train_Precision: 0.7418181896209717
  train_Recall: 0.7418181896209717
  train_loss: 0.47712825411471765
  training_iteration: 2
  trial_id: f5502_00000
  val_Accuracy: 0.682758629322052
  val_ConfusionMatrix:
  - - 900
    - 1
    - 161
  - - 138
    - 16
    - 108
  - - 128
    - 16
    - 272
  val_F1Score: 0.682758629322052
  val_Precision: 0.682758629322052
  val_Recall: 0.682758629322052
  val_loss: 0.8418892386429216
  warmup_time: 1.6474692821502686
(RayTrainWorker pid=517) 2023-04-04 18:34:26,387 - INFO - train_tissue_classifier - Completed model training
Trial TorchTrainer_f5502_00000 completed.
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - == Status ==
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Current time: 2023-04-04 18:34:29 (running for 00:02:41.80)
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Memory usage on this node: 6.1/20.3 GiB
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Using FIFO scheduling algorithm.
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Resources requested: 0/5 CPUs, 0/0 GPUs, 0.0/11.81 GiB heap, 0.0/5.91 GiB objects
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Result logdir: /home/limr/vmount/PRO-12-123/tissue_classifier_results/TorchTrainer_2023-04-04_18-31-47
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - Number of trials: 1/1 (1 TERMINATED)
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - +--------------------------+------------+----------------+--------+------------------+--------------+------------+--------------+
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - | Trial name               | status     | loc            |   iter |   total time (s) |   train_loss |   val_loss |   _timestamp |
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - |--------------------------+------------+----------------+--------+------------------+--------------+------------+--------------|
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - | TorchTrainer_f5502_00000 | TERMINATED | 172.21.0.5:458 |      2 |          153.666 |     0.477128 |   0.841889 |   1680633265 |
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier - +--------------------------+------------+----------------+--------+------------------+--------------+------------+--------------+
2023-04-04 18:34:29,511 - INFO - train_tissue_classifier -
2023-04-04 18:34:29,517 INFO tune.py:798 -- Total run time: 161.84 seconds (161.78 seconds for the tuning loop).
2023-04-04 18:34:29,541 - INFO - train_tissue_classifier - Finished training
2023-04-04 18:34:29,551 - INFO - train_tissue_classifier -    train_Accuracy  ...  logdir
2023-04-04 18:34:29,551 - INFO - train_tissue_classifier - 0        0.741818  ...  /home/limr/vmount/PRO-12-123/tissue_classifier...
2023-04-04 18:34:29,551 - INFO - train_tissue_classifier -
2023-04-04 18:34:29,551 - INFO - train_tissue_classifier - [1 rows x 39 columns]
2023-04-04 18:34:29,609 - INFO - train_tissue_classifier - Output: /home/limr/vmount/PRO-12-123/tissue_classifier_results
2023-04-04 18:34:31,934 - INFO - luna.common.utils - Code block 'transform::train_model' took: 167.3163875049795s
2023-04-04 18:34:31,937 - INFO - luna.common.utils - ------------------------------------------------------------
2023-04-04 18:34:31,937 - INFO - luna.common.utils - Done with transform, running post-transform functions...
2023-04-04 18:34:31,937 - INFO - luna.common.utils - ------------------------------------------------------------
2023-04-04 18:34:31,966 - INFO - luna.common.utils - Done.
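The Ray status lines above report "Resources requested: 5.0/5 CPUs" for a command that asked for one worker with four CPUs. A plausible accounting, assuming Ray Train's default of reserving one CPU for the training coordinator (trainer_resources) on top of the workers, is sketched below. required_cpus is a hypothetical helper for budgeting the flags, not part of LUNA or Ray.

```python
def required_cpus(num_workers: int, num_cpus_per_worker: int, trainer_cpus: int = 1) -> int:
    """CPUs Ray must reserve for one training trial.

    Assumption: Ray Train reserves `trainer_cpus` (1 by default) for the
    training coordinator, plus `num_cpus_per_worker` for every worker.
    """
    return trainer_cpus + num_workers * num_cpus_per_worker


# Values from the example command: --num_workers 1 --num_cpus_per_worker 4 --num_cpus 5
num_workers, num_cpus_per_worker, num_cpus = 1, 4, 5

needed = required_cpus(num_workers, num_cpus_per_worker)
print(f"{needed}/{num_cpus} CPUs requested")  # 5/5 CPUs requested

# Fail early instead of launching a trial Ray cannot schedule
if needed > num_cpus:
    raise ValueError(f"Trial needs {needed} CPUs but Ray was given only {num_cpus}")
```

Under this assumption, raising --num_workers or --num_cpus_per_worker means raising --num_cpus as well; otherwise Ray cannot schedule the trial.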
Results
Now that we have a trained model, we can inspect the output directory of the run:
%%bash
ls /home/limr/vmount/PRO-12-123/tissue_classifier_results/TorchTrainer_2023-04-04_18-31-47
TorchTrainer_f5502_00000_0_2023-04-04_18-31-47 basic-variant-state-2023-04-04_18-31-47.json experiment_state-2023-04-04_18-31-47.json trainable.pkl tuner.pkl
!ls -lat ../PRO-12-123/tissue_classifier_results
total 87496
-rw-r--r--  1 limr limr    12354 Apr  4 18:34 metadata.yml
drwxr-xr-x  7 limr limr      224 Apr  4 18:34 TorchTrainer_2023-04-04_18-31-47
-rw-r--r--  1 limr limr 44788557 Apr  4 18:34 checkpoint_1.pt
drwxr-xr-x 17 limr limr      544 Apr  4 18:34 .
-rw-r--r--  1 limr limr 44788557 Apr  4 18:33 checkpoint_0.pt
drwxr-xr-x  7 limr limr      224 Apr  4 17:52 TorchTrainer_2023-04-04_17-51-33
drwxr-xr-x  7 limr limr      224 Apr  4 17:46 TorchTrainer_2023-04-04_17-45-03
drwxr-xr-x  7 limr limr      224 Apr  4 17:11 TorchTrainer_2023-04-04_17-10-05
drwxr-xr-x  7 limr limr      224 Apr  4 16:08 TorchTrainer_2023-04-04_16-06-53
drwxr-xr-x  7 limr limr      224 Apr  3 21:43 TorchTrainer_2023-04-03_21-43-48
drwxr-xr-x  7 limr limr      224 Apr  3 21:37 TorchTrainer_2023-04-03_21-37-39
drwxr-xr-x  7 limr limr      224 Apr  3 21:21 TorchTrainer_2023-04-03_21-21-44
drwxr-xr-x  7 limr limr      224 Apr  3 21:09 TorchTrainer_2023-04-03_21-09-44
drwxr-xr-x  6 limr limr      192 Apr  3 21:09 TorchTrainer_2023-04-03_20-57-21
drwxr-xr-x  6 limr limr      192 Apr  3 20:51 TorchTrainer_2023-04-03_20-35-01
drwxr-xr-x  4 limr limr      128 Apr  3 18:02 TorchTrainer_2023-04-03_18-02-38
drwxr-xr-x  8 root root      256 Mar 31 21:43 ..
Each time the model is trained, Ray creates a new timestamped directory (TorchTrainer_<date>_<time>) inside the output directory to keep the runs apart. You can inspect a run by loading its directory with Ray Tune's ExperimentAnalysis class. Its results_df attribute is a dataframe that stores the performance metrics, the hyperparameters used to configure the model, and other output metadata.
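Because every run gets its own directory, you have to choose which one to load. As a convenience, the standard-library sketch below picks the most recent run by name. latest_run_dir is a hypothetical helper, and it assumes the zero-padded TorchTrainer_YYYY-MM-DD_HH-MM-SS naming shown in the listing above, which sorts chronologically as plain strings.

```python
import os


def latest_run_dir(output_dir: str, prefix: str = "TorchTrainer_") -> str:
    """Return the path of the newest Ray run directory inside output_dir.

    Assumes run directories are named <prefix>YYYY-MM-DD_HH-MM-SS, so the
    lexicographically largest name is also the most recent run.
    """
    runs = [
        name
        for name in os.listdir(output_dir)
        # Skip checkpoints, metadata.yml and anything else that is not a run directory
        if name.startswith(prefix) and os.path.isdir(os.path.join(output_dir, name))
    ]
    if not runs:
        raise FileNotFoundError(f"No {prefix}* directories found in {output_dir}")
    return os.path.join(output_dir, max(runs))
```

For example, latest_run_dir("../PRO-12-123/tissue_classifier_results") would select TorchTrainer_2023-04-04_18-31-47 for the listing above, and the result can be passed straight to ExperimentAnalysis.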
import os

from IPython.display import display
from ray.tune import ExperimentAnalysis

RAY_OUTPUT = "TorchTrainer_2023-04-04_18-31-47"  # change this to the run directory you want to inspect
output_dir = "../PRO-12-123/tissue_classifier_results"
ray_output_dir = os.path.join(output_dir, RAY_OUTPUT)

# Load the experiment state Ray saved for this run; results_df has one row per trial
analysis = ExperimentAnalysis(ray_output_dir)
display(analysis.results_df)
trial_id | train_Accuracy | train_Precision | train_Recall | train_F1Score | train_loss | val_Accuracy | val_Precision | val_Recall | val_F1Score | val_ConfusionMatrix | ... | iterations_since_restore | warmup_time | experiment_tag | config/scaling_config/trainer_resources | config/scaling_config/num_workers | config/scaling_config/use_gpu | config/scaling_config/placement_strategy | config/scaling_config/_max_cpu_fraction_per_node | config/scaling_config/resources_per_worker/CPU | config/scaling_config/resources_per_worker/GPU
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
f5502_00000 | 0.7418182 | 0.7418182 | 0.7418182 | 0.7418182 | 0.477128 | 0.6827586 | 0.6827586 | 0.6827586 | 0.6827586 | [[900, 1, 161], [138, 16, 108], [128, 16, 272]] | ... | 2 | 1.647469 | 0 | None | 1 | False | PACK | None | 4 | 0
1 rows × 38 columns
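All four validation metrics in the table equal 0.6827586, and the four training metrics are likewise identical. This is what micro-averaging produces for single-label multiclass data: pooling every tile makes precision, recall and F1 all collapse to accuracy. The sketch below recomputes these numbers from the epoch-2 validation confusion matrix and contrasts them with per-class and macro-averaged values. It assumes rows are true classes and columns are predicted classes, and uses the label order from label_set.

```python
# Epoch-2 validation confusion matrix, copied from results_df above.
# Assumption: rows are true classes, columns are predicted classes.
labels = ["tumor", "stroma", "fat"]
cm = [
    [900, 1, 161],
    [138, 16, 108],
    [128, 16, 272],
]

n = len(labels)
total = sum(sum(row) for row in cm)
correct = sum(cm[i][i] for i in range(n))

# Micro-averaging pools all tiles: every error is one false positive and one
# false negative, so precision, recall and F1 all equal correct / total
micro = correct / total

# Per-class recall (row-wise) and precision (column-wise)
recall = [cm[i][i] / sum(cm[i]) for i in range(n)]
precision = [cm[i][i] / sum(cm[r][i] for r in range(n)) for i in range(n)]

# Macro-averaging gives each class equal weight regardless of its size
macro_recall = sum(recall) / n
macro_precision = sum(precision) / n

print(f"micro (= accuracy): {micro:.3f}")  # 0.683
for name, r, p in zip(labels, recall, precision):
    print(f"{name:>6}: recall={r:.3f} precision={p:.3f}")
print(f"macro recall={macro_recall:.3f} macro precision={macro_precision:.3f}")
```

Under that row/column assumption, stroma recall is only about 6% (16 of 262 stroma tiles), a weakness that the identical micro-averaged columns hide.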
We can use the validation confusion matrix stored in results_df to plot a confusion matrix, normalizing each row so that it sums to 1:
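import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

label_dict = {'tumor': 0, 'stroma': 1, 'fat': 2}
# Order the label names by their class index so they line up with the matrix
labels = sorted(label_dict, key=label_dict.get)

# The matrix may be stored as a nested list, so convert it to a float array
cm = np.array(analysis.results_df['val_ConfusionMatrix'].iloc[0], dtype=float)

# Normalize each row so it sums to 1
cm = cm / cm.sum(axis=1, keepdims=True)

df_cm = pd.DataFrame(cm, index=labels, columns=labels)
display(df_cm)

sns.heatmap(df_cm, annot=True, fmt='.2f')
plt.show()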
The top-level output directory (../PRO-12-123/tissue_classifier_results) also contains the model checkpoints, checkpoint_*.pt, that we'll need for inference. With the trained model and its checkpoints in hand, we can move on to the next notebook!