Advanced Topics
Configuration file
MIST uses a centralized config.json file to store dataset information,
preprocessing parameters, model settings, training hyperparameters, inference
strategies, and evaluation metrics.
This file is automatically generated during the analysis step and is required for all subsequent stages of the pipeline (preprocessing, training, inference, and evaluation).
If you want to customize training or inference, you can edit this file directly. For example, you can safely change the network architecture, loss function, or optimizer.
Editing preprocessing is possible but usually not recommended unless you know
your dataset requires different resampling or normalization. If, for example,
you want to apply a different normalization scheme or preprocessing techniques
that are not available in MIST, then we recommend applying your technique to
your images and saving them as NIfTI files. Then, you can either set the
preprocessing.skip entry to true or use the --no-preprocess flag with the
mist_run_all or mist_preprocess commands.
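As a minimal sketch of the first option, the snippet below sets preprocessing.skip to true in an existing config.json using only the Python standard library. The helper name `set_skip_preprocessing` is hypothetical and not part of MIST; the key path follows the example configuration in this section.

```python
import json
from pathlib import Path


def set_skip_preprocessing(config_path: str, skip: bool = True) -> dict:
    """Set preprocessing.skip in a MIST config.json and write the file back.

    Hypothetical helper for illustration; not part of the MIST API.
    """
    path = Path(config_path)
    config = json.loads(path.read_text())
    # Create the section if it is missing so partial files are handled too.
    config.setdefault("preprocessing", {})["skip"] = skip
    path.write_text(json.dumps(config, indent=2))
    return config
```

All other entries are written back unchanged, so the rest of the pipeline sees the same settings that the analysis step produced.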
Below is an example of a valid config.json file.
{
"mist_version": "2.0.0rc0",
"dataset_info": {
"task": "ivygap",
"modality": "mr",
"images": ["t1", "t2", "tc", "fl"],
"labels": [0, 1, 2, 4]
},
"spatial_config": {
"patch_size": [128, 128, 128],
"target_spacing": [1.0, 1.0, 1.0]
},
"preprocessing": {
"skip": false,
"crop_to_foreground": true,
"median_resampled_image_size": [135, 174, 139],
"normalize_with_nonzero_mask": true,
"ct_normalization": {
"window_min": null,
"window_max": null,
"z_score_mean": null,
"z_score_std": null
},
"compute_dtms": false,
"normalize_dtms": true
},
"model": {
"architecture": "nnunet",
"params": {
"in_channels": 4,
"out_channels": 4
}
},
"training": {
"seed": 42,
"nfolds": 5,
"folds": [0, 1, 2, 3, 4],
"val_percent": 0.0,
"epochs": 1000,
"min_steps_per_epoch": 250,
"batch_size_per_gpu": 2,
"dali_foreground_prob": 0.6,
"loss": {
"name": "dice_ce",
"composite_loss_weighting": null
},
"optimizer": "adamw",
"learning_rate": 0.001,
"lr_scheduler": "cosine",
"warmup_epochs": 20,
"l2_penalty": 1e-04,
"grad_clip_norm": 1.0,
"amp": true,
"augmentation": {
"enabled": true,
"transforms": {
"flips": true,
"zoom": true,
"noise": true,
"blur": true,
"brightness": true,
"contrast": true
}
},
"hardware": {
"num_gpus": 1,
"num_cpu_workers": 8,
"master_addr": "localhost",
"master_port": 12345,
"communication_backend": "nccl"
}
},
"inference": {
"inferer": {
"name": "sliding_window",
"params": {
"patch_blend_mode": "gaussian",
"patch_overlap": 0.5
}
},
"ensemble": {
"strategy": "mean"
},
"tta": {
"enabled": true,
"strategy": "all_flips"
}
},
"evaluation": {
"wt": {
"labels": [1, 2, 4],
"metrics": {
"dice": {},
"haus95": {}
}
},
"tc": {
"labels": [1, 4],
"metrics": {
"dice": {},
"haus95": {}
}
},
"et": {
"labels": [4],
"metrics": {
"dice": {},
"haus95": {}
}
}
}
}
The config.json file is divided into several major sections. Each section controls
a different part of the MIST pipeline:
| Section | Purpose |
|---|---|
| mist_version | Tracks the version of MIST used to generate the configuration. |
| dataset_info | Metadata about the dataset, including task name, modality, image inputs, and label values. |
| spatial_config | Single source of truth for patch_size and target_spacing, shared across preprocessing, model construction, and inference. |
| preprocessing | Defines how raw images are resampled, cropped, and normalized before training. |
| model | Specifies the model architecture and its hyperparameters (in_channels, out_channels, and any architecture-specific params such as kernel_size for MedNeXt). |
| training | Controls training loop, cross-validation folds, loss function, optimizer, and augmentations. |
| inference | Settings for inference, including sliding-window parameters, ensembling, and test-time augmentation. patch_overlap must be in [0, 1); a value of 1.0 is invalid. |
| evaluation | Metrics and class definitions for model evaluation. |
Patch size selection
The analysis step automatically selects a patch size and writes it to
spatial_config.patch_size in config.json. The algorithm uses available
GPU memory and the target spacing to choose the best patch for the data. The
patch size can always be overridden by editing config.json or passing
--patch-size to mist_train.
Voxel budget
The per-patch voxel budget is derived from the minimum GPU memory across all available CUDA devices, scaled linearly from a 16 GB / batch-size-2 reference, and inversely proportional to the configured batch size per GPU:
budget = (min_gpu_memory / 16 GB) × 128³ × (2 / batch_size_per_gpu)
This means an 8 GB GPU at batch size 2 yields a budget of 128³ / 2 ≈ 1 M
voxels, a 40 GB A100 at batch size 2 yields 128³ × 2.5 ≈ 5.2 M voxels
(roughly 174³, which becomes 160 per axis once snapped down to a multiple of
32), and doubling the batch size to 4 halves the per-patch budget to keep
total memory per step constant. If no GPU is detected at analysis time, the
fallback budget is 128³.
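The budget formula above can be sketched in a few lines of Python. The function name `voxel_budget` and its arguments are hypothetical (MIST queries the CUDA devices itself); only the constants and the formula come from the text.

```python
REFERENCE_MEMORY_GB = 16.0   # reference GPU memory from the formula above
REFERENCE_BATCH_SIZE = 2     # reference batch size per GPU
REFERENCE_VOXELS = 128 ** 3  # reference per-patch voxel budget


def voxel_budget(gpu_memory_gb, batch_size_per_gpu=2):
    """Per-patch voxel budget for a list of per-device memory sizes in GB.

    An empty list stands for "no GPU detected" and returns the 128^3 fallback.
    """
    if not gpu_memory_gb:
        return float(REFERENCE_VOXELS)
    min_memory = min(gpu_memory_gb)  # the smallest device limits the patch
    memory_scale = min_memory / REFERENCE_MEMORY_GB
    batch_scale = REFERENCE_BATCH_SIZE / batch_size_per_gpu
    return memory_scale * REFERENCE_VOXELS * batch_scale
```

For example, `voxel_budget([8.0])` and `voxel_budget([16.0], batch_size_per_gpu=4)` both give half of the reference budget.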
The 128³ reference is a conservative heuristic chosen to leave headroom
for network activations, gradients, and optimizer state — it is not tuned to
a specific architecture. Heavier models (e.g. MedNeXt-large, SwinUNETR) may
require a smaller patch size than the one selected automatically. If training
runs out of memory, reduce patch_size in config.json or via --patch-size.
If there is memory headroom to spare, increasing the patch size may improve
segmentation accuracy.
Note
The budget is only used during analysis to choose the initial patch size.
It is not stored in config.json and does not affect reproducibility —
only the resulting patch_size values matter for training.
3D isotropic mode
Used when max(target_spacing) / min(target_spacing) ≤ 3. A physically
isotropic patch extent is computed from the budget:
target_mm = (budget × prod(target_spacing))^(1/3)
patch[i] = target_mm / target_spacing[i] for each axis i
Any axis whose raw patch would exceed the median resampled image size is clamped to that size; the freed voxels are redistributed to the remaining axes. Every axis is then snapped down to the nearest multiple of 32 ≤ the clamped value (minimum 32), so a clamped axis ends up at the largest multiple of 32 that fits within the median image size. This snapping is required because the nnUNet encoder applies up to 5 stride-2 downsampling steps (2⁵ = 32), and the decoder must reconstruct the output at the original resolution — any dimension not divisible by 32 would produce a size mismatch between encoder and decoder feature maps.
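The steps above can be sketched as follows. This is an illustration under stated assumptions, not MIST's implementation: the function name `isotropic_patch` is hypothetical, and the redistribution rule (re-solving the isotropic extent over the axes that are not yet clamped) is one plausible reading of "the freed voxels are redistributed".

```python
import math


def isotropic_patch(budget, spacing, median_size, multiple=32):
    """Sketch of 3D isotropic patch selection (hypothetical helper)."""
    n = len(spacing)
    # Physically isotropic extent in mm, then converted to voxels per axis.
    target_mm = (budget * math.prod(spacing)) ** (1.0 / n)
    raw = [target_mm / s for s in spacing]
    clamped = [False] * n

    for _ in range(n):
        over = [i for i in range(n) if not clamped[i] and raw[i] > median_size[i]]
        if not over:
            break
        for i in over:
            raw[i] = float(median_size[i])
            clamped[i] = True
        free = [i for i in range(n) if not clamped[i]]
        if not free:
            break
        # Assumed redistribution: spend the remaining budget isotropically
        # (in mm) over the axes that are still free.
        used = math.prod(raw[i] for i in range(n) if clamped[i])
        free_spacing = math.prod(spacing[i] for i in free)
        free_mm = ((budget / used) * free_spacing) ** (1.0 / len(free))
        for i in free:
            raw[i] = free_mm / spacing[i]

    # round() guards against float error such as 127.9999999 before flooring.
    return [max(multiple, int(round(r, 6) // multiple) * multiple) for r in raw]
```

With the reference budget of 128³, 1 mm isotropic spacing, and the median image size from the example configuration, this sketch reproduces the patch size [128, 128, 128] shown there.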
Quasi-2D mode
Used when max(target_spacing) / min(target_spacing) > 3. In highly
anisotropic data (e.g., thick-slice CT where the z-spacing is 5–10× the
in-plane spacing), allocating equal voxels to every axis would waste most of
the budget on the low-resolution axis while under-sampling the high-resolution
in-plane axes. Quasi-2D mode instead maximises in-plane resolution by keeping
the low-resolution axis thin.
The low-resolution axis is whichever axis has the largest spacing (not assumed to be z). Its patch size is chosen as the largest value that still leaves the full budget available for the two in-plane axes:
lr_patch = budget / max(in_plane_median)²
clamped to [MIN_LOW_RES_AXIS_PATCH_SIZE, median_lr]
ip_patch = sqrt(budget / lr_patch)
snapped to nearest multiple of 32, capped at median_ip
Both in-plane axes receive the same patch size (square patch). The low-resolution axis is not snapped to a multiple of 32 — the nnUNet uses stride-1 (no downsampling) in that direction, so feature map sizes are preserved and no divisibility constraint applies. Any size ≥ 1 is valid. The minimum low-resolution patch size is 5, ensuring a 3×3×3 convolution kernel has genuine (non-padded) context at every position along that axis.
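A minimal sketch of the quasi-2D rule, under several assumptions: the function name `quasi_2d_patch` is hypothetical, median_ip is taken to be the larger of the two in-plane median sizes, and the in-plane size is snapped down (not to the nearest multiple) to mirror the isotropic mode. MIST's own implementation may differ on these points.

```python
import math

MIN_LOW_RES_AXIS_PATCH_SIZE = 5  # minimum stated in the text above


def quasi_2d_patch(budget, spacing, median_size, multiple=32):
    """Sketch of quasi-2D patch selection (hypothetical helper)."""
    # The low-resolution axis is the one with the largest spacing.
    lr_axis = max(range(len(spacing)), key=lambda i: spacing[i])
    in_plane = [i for i in range(len(spacing)) if i != lr_axis]
    max_ip_median = max(median_size[i] for i in in_plane)

    # lr_patch = budget / max(in_plane_median)^2, clamped to [MIN, median_lr].
    lr_patch = int(budget / max_ip_median ** 2)
    lr_patch = max(MIN_LOW_RES_AXIS_PATCH_SIZE, min(lr_patch, median_size[lr_axis]))

    # ip_patch = sqrt(budget / lr_patch), capped and snapped down to 32.
    ip_patch = min(math.sqrt(budget / lr_patch), max_ip_median)
    ip_patch = max(multiple, int(ip_patch // multiple) * multiple)

    patch = [ip_patch] * len(spacing)
    patch[lr_axis] = lr_patch  # not snapped: stride-1 along this axis
    return patch
```

For thick-slice data with spacing [0.8, 0.8, 5.0] and a median size of [512, 512, 60], the reference budget of 128³ gives a patch of [512, 512, 8] in this sketch.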
Overriding the patch size
If the automatically selected patch is not suitable (e.g., different GPU memory, unusual anatomy), you can override it in two ways:
In config.json:
"spatial_config": {
"patch_size": [192, 192, 96]
}
Via CLI (overrides config.json):
mist_train --numpy /path/to/numpy --results /path/to/results \
--patch-size 192 192 96
Each architecture imposes its own patch size constraints when overriding manually. The key distinction is between adaptive and fixed topology:
- Adaptive (nnUNet, FMG-Net, W-Net): the architecture inspects target_spacing at construction time and skips downsampling on any axis that is already low-resolution (applying stride-1 there instead). Those axes accept any patch size ≥ 1.
- Fixed (MedNeXt, SwinUNETR): the same stride schedule is applied to every axis regardless of spacing, so all dimensions must meet the divisibility requirement.
| Architecture | Constraint | Notes |
|---|---|---|
| nnunet, nnunet-pocket | Multiples of 32 on downsampled axes | Adaptive: the low-res axis uses stride-1, so any value is valid there |
| fmgnet, wnet | Multiples of 32 on downsampled axes | Same adaptive topology as nnUNet |
| mednext-* | All dimensions divisible by 16 | Fixed 4-stage encoder (2⁴ = 16); no adaptation for anisotropic axes |
| swinunetr-* | All dimensions divisible by 32 | Fixed: patch_size=2 tokenizer × 4 downsampling stages; raises an error at model construction if violated |
The automatic patch selection always produces multiples of 32 (which satisfies all constraints), so manual overrides are the only case where you need to check this.
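Before launching a run with a manual override, the constraints in the table can be checked with a small helper such as the one below. The function `check_patch_size` is hypothetical and not part of MIST. Because the adaptive architectures decide which axes are low-resolution from target_spacing, the sketch assumes the caller passes those axes explicitly through `low_res_axes`.

```python
def check_patch_size(architecture, patch_size, low_res_axes=()):
    """Return a list of constraint violations (empty if the patch is valid)."""
    if architecture.startswith("mednext"):
        divisor, adaptive = 16, False
    elif architecture.startswith("swinunetr"):
        divisor, adaptive = 32, False
    elif architecture in {"nnunet", "nnunet-pocket", "fmgnet", "wnet"}:
        divisor, adaptive = 32, True
    else:
        raise ValueError(f"Unknown architecture: {architecture}")

    problems = []
    for axis, size in enumerate(patch_size):
        if adaptive and axis in low_res_axes:
            # Stride-1 axis: any size >= 1 is accepted.
            if size < 1:
                problems.append(f"axis {axis}: {size} must be >= 1")
            continue
        if size % divisor != 0:
            problems.append(f"axis {axis}: {size} is not divisible by {divisor}")
    return problems
```

For example, `check_patch_size("nnunet", [256, 256, 5], low_res_axes=(2,))` returns an empty list, while the same patch reports a violation on axis 2 for any mednext-* or swinunetr-* variant.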
Network architectures
All MIST models are 3D-only. For highly anisotropic data (e.g., thick-slice
CT), use a thin patch along the anisotropic axis (e.g., 256 256 5) rather
than a 2D patch — the analysis step will choose this automatically.
MIST's default network architecture is the 3D nnUNet with a residual encoder and two deep supervision heads. However, the following architectures are also available:
| Architecture | model.architecture |
|---|---|
| nnUNet (default) | nnunet |
| nnUNet (pocket) | nnunet-pocket |
| MedNeXt (small) | mednext-small |
| MedNeXt (base) | mednext-base |
| MedNeXt (medium) | mednext-medium |
| MedNeXt (large) | mednext-large |
| FMG-Net | fmgnet |
| W-Net | wnet |
| SwinUNETR-V2 (small) | swinunetr-small |
| SwinUNETR-V2 (base) | swinunetr-base |
| SwinUNETR-V2 (large) | swinunetr-large |
The architecture can be specified with the --model flag in the mist_run_all
or mist_train commands or directly edited in the config.json file.
A pocket version of nnUNet is available as a dedicated architecture. Rather than
doubling the filter count at each depth level, nnunet-pocket uses a constant
32 filters throughout the network. This reduces the parameter count
significantly, which can make it a good choice when GPU memory or compute is
limited. We used this model in the BraTS 2024 and 2025 adult glioma
competitions and placed 3rd both years with a pocket model that had ~700K
parameters.
Example
Run the MIST training pipeline with the pocket nnUNet.
mist_train --numpy /path/to/preprocessed/npy/files \
--results /path/to/results/folder \
--model nnunet-pocket
MedNeXt kernel size
MedNeXt's depthwise convolutions support kernel sizes of 3, 5, and 7.
Larger kernels capture longer-range spatial context and are the primary
differentiator of MedNeXt over nnUNet, at the cost of higher memory usage.
The default is 3, which matches nnUNet's receptive field and is the safest
choice for memory-constrained hardware.
| kernel_size | Receptive field | Memory cost |
|---|---|---|
| 3 | Standard | Low |
| 5 | Extended | Moderate |
| 7 | Large | High |
kernel_size is set via the model.params section of config.json — it is
not exposed as a CLI flag. For example, to use a kernel size of 5 with
MedNeXt base:
"spatial_config": {
"patch_size": [128, 128, 128],
"target_spacing": [1.0, 1.0, 1.0]
},
"model": {
"architecture": "mednext-base",
"params": {
"in_channels": 1,
"out_channels": 2,
"kernel_size": 5
}
}
MedNeXt patch size constraints
MedNeXt uses a fixed 4-stage encoder (blocks_down has length 4 for all
variants). Each stage halves the spatial resolution by stride-2 convolution,
so all patch dimensions must be divisible by 2⁴ = 16. Unlike nnUNet, there
is no adaptive topology — the same stride schedule is applied to every axis
regardless of spacing anisotropy.
The automatic patch selection always produces multiples of 32 (which are also
multiples of 16), so this constraint is satisfied out of the box. If you
manually override patch_size, ensure all three dimensions are multiples of
16. Values that are multiples of 32 also work and are recommended for
consistency with other architectures.
SwinUNETR-V2 patch size constraints
SwinUNETR-V2 requires that all three spatial dimensions of the input patch are
divisible by 32 (the patch_size=2 tokenizer followed by four
downsampling stages: 2 × 2⁴ = 32). Unlike nnUNet and MedNeXt, SwinUNETR
validates this constraint at model construction time and raises a ValueError
immediately if it is violated, rather than failing silently at the first
forward pass.
There is no adaptive topology — all three axes are always processed with the same stride schedule, so the constraint applies to every dimension including the anisotropic (low-res) axis. The automatic patch selection always produces multiples of 32, so this is satisfied out of the box.
FMG-Net and W-Net design
FMG-Net and W-Net are multigrid architectures where representational capacity comes from the network topology rather than channel widening. Both always use the pocket paradigm (constant 32 filters at every depth level) and always have residual blocks and deep supervision enabled.
The two variants differ in their traversal schedule:
- FMG-Net (fmgnet): progressive schedule with one full V-cycle at each intermediate resolution before diving to the next, producing a staircase of increasing spike heights [1, 2, 3, …, max].
- W-Net (wnet): sparse W-pattern schedule that alternates shallow and deep V-cycles across the full depth, producing a symmetric pattern like [1, 2, 1, 3, 1, 2, 1].
Both variants adapt their depth and stride schedule to the input patch_size
and target_spacing, the same way nnUNet does.
Adding a custom model
MIST uses a registry pattern to discover model architectures. Adding a new
model requires three steps: implement the model class, register a factory
function, and add an import to mist/models/__init__.py.
Step 1 — Implement the model class.
Create a file mist/models/mymodel/mist_mymodel.py. The class must be a
torch.nn.Module. In training mode, forward should return a dict with a
"prediction" key (and optionally a "deep_supervision" list). In eval mode,
return the logit tensor directly.
Do not apply softmax to model outputs
MIST models must return raw logits — never apply softmax (or any other
activation) to the final output. Softmax is applied inside the loss
function's preprocess() step. Applying it in the model as well will
silently produce incorrect gradients and degraded training performance.
# mist/models/mymodel/mist_mymodel.py
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs):
        super().__init__()
        self.net = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        output = self.net(x)  # Raw logits; no softmax here.
        if self.training:
            return {"prediction": output}
        return output
Step 2 — Register a factory function.
Create mist/models/mymodel/mymodel_registry.py. Use the @register_model
decorator to make the architecture selectable by name.
# mist/models/mymodel/mymodel_registry.py
from mist.models.mymodel.mist_mymodel import MyModel
from mist.models.model_registry import register_model


@register_model("mymodel")
def create_mymodel(**kwargs) -> MyModel:
    required_keys = ["in_channels", "out_channels"]
    for key in required_keys:
        if key not in kwargs:
            raise ValueError(f"Missing required key '{key}' in model configuration.")
    return MyModel(
        in_channels=kwargs["in_channels"],
        out_channels=kwargs["out_channels"],
    )
Step 3 — Trigger registration at import time.
Add an import to mist/models/__init__.py. Registration happens when the
module is imported, so the factory must be imported before any registry lookup.
# mist/models/__init__.py (add this line)
from mist.models.mymodel.mymodel_registry import create_mymodel
After these three steps, --model mymodel will work with mist_train and
mist_run_all.
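To see why Step 3 matters, it can help to look at how a decorator-based registry typically works. The sketch below is a generic stand-in, not MIST's actual code: the real decorator lives in mist.models.model_registry and may differ in detail, and `get_model` is a hypothetical lookup name.

```python
from typing import Callable, Dict

# Hypothetical stand-in for a model registry (name -> factory function).
_MODEL_REGISTRY: Dict[str, Callable[..., object]] = {}


def register_model(name: str) -> Callable:
    """Decorator that stores a factory function under `name`."""
    def decorator(factory: Callable[..., object]) -> Callable[..., object]:
        if name in _MODEL_REGISTRY:
            raise KeyError(f"Model '{name}' is already registered.")
        _MODEL_REGISTRY[name] = factory
        return factory
    return decorator


def get_model(name: str, **kwargs) -> object:
    """Look up a factory by name and build the model."""
    if name not in _MODEL_REGISTRY:
        raise KeyError(f"Unknown model '{name}'. Available: {sorted(_MODEL_REGISTRY)}")
    return _MODEL_REGISTRY[name](**kwargs)


# The decorator runs when this definition is executed, i.e. at import time.
@register_model("mymodel")
def create_mymodel(**kwargs) -> dict:
    return {"name": "mymodel", **kwargs}  # placeholder for a real nn.Module
```

In a registry of this kind, a factory whose module is never imported is never added to the dictionary, so a lookup by name fails even though the code exists on disk.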
Loss functions
MIST's default loss function is Dice with cross entropy. However, the following loss functions are also available in MIST:
| Loss function | training.loss.name |
|---|---|
| Dice w/ cross entropy | dice_ce |
| Dice | dice |
| Boundary | bl |
| One-sided Hausdorff | hdos |
| Generalized surface | gsl |
| clDice | cldice |
| Volumetric surface Dice dilation (experimental) | volumetric_sddl |
| Vessel surface Dice dilation (experimental) | vessel_sddl |
Experimental loss functions
volumetric_sddl and vessel_sddl are experimental. They have not been
validated against a published benchmark and the API is subject to change.
Use dice_ce as a baseline before trying these.
What they do. Both losses measure overlap between dilated predicted
and ground-truth boundaries rather than exact voxel boundaries, tolerating
clinically insignificant spatial errors up to a physical tolerance tau_mm.
Dilation kernels are computed anisotropically from the voxel spacing so
that the tolerance is consistent in physical millimetres across all axes.
- volumetric_sddl: Dice+CE + Surface Dice Dilation. Intended for solid organs and tumours.
- vessel_sddl: clDice + Surface Dice Dilation. Intended for thin, branching structures such as vessels or airways.
Spacing. Both variants read voxel spacing automatically from
spatial_config.target_spacing in config.json — no manual parameter
is required. The surface tolerance defaults to tau_mm: "auto", which
sets it to max(spacing) * 1.25.
DTMs. Neither SDDL variant uses precomputed DTMs — boundaries are
computed directly from model predictions during training. The
--compute-dtms flag is not required.
The loss function can be specified with the --loss flag in the mist_run_all
or mist_train commands or set in the config.json file under the
training.loss.name attribute.
Using distance transform maps
The boundary (bl), one-sided Hausdorff (hdos), and generalized surface
(gsl) loss functions require precomputed distance transform maps (DTMs).
Enable DTM computation with the --compute-dtms flag when running
mist_run_all or mist_preprocess. MIST automatically detects which losses
need DTMs from the loss name — no additional flag is required at training time.
You can also enable DTM computation directly in config.json:
"preprocessing": {
"compute_dtms": true,
"normalize_dtms": true
}
DTM normalization. MIST uses signed distance transforms, where interior voxels (inside the label region) have negative values and exterior voxels (outside) have positive values. The magnitude at any voxel is its distance to the nearest boundary, measured in voxels.
By default, MIST normalizes DTMs per label per volume by dividing interior
distances by the largest interior distance and exterior distances by the largest
exterior distance. This maps the interior to [-1, 0] and the exterior to
[0, 1], with the boundary itself remaining at 0. Normalization is applied
independently for each label, so a small structure and a large structure always
produce DTMs in the same range.
Without normalization, raw distances scale with structure size — a large organ may have interior distances in the hundreds of voxels, producing much larger gradients from the boundary loss term than a small lesion would. This imbalance can destabilize training and is particularly problematic under automatic mixed precision (AMP), where large values risk overflow in float16.
Normalization is strongly recommended and enabled by default. To disable it,
set preprocessing.normalize_dtms to false in config.json.
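The normalization rule described above can be illustrated on a flat list of signed distances for a single label. This is a pure-Python sketch with a hypothetical function name; MIST operates on full 3D arrays and its implementation may differ.

```python
def normalize_signed_dtm(dtm):
    """Scale a signed DTM so the interior maps to [-1, 0] and the exterior to [0, 1].

    `dtm` is a flat list of signed distances for one label in one volume
    (negative inside the label, positive outside, 0 on the boundary).
    """
    max_interior = max((-v for v in dtm if v < 0), default=0.0)
    max_exterior = max((v for v in dtm if v > 0), default=0.0)

    normalized = []
    for v in dtm:
        if v < 0:
            normalized.append(v / max_interior)
        elif v > 0:
            normalized.append(v / max_exterior)
        else:
            normalized.append(0.0)  # boundary voxels stay at 0
    return normalized
```

A structure with interior distances up to 4 voxels and one with interior distances up to 400 voxels both end up in the same [-1, 1] range, which is the property the paragraph above relies on.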
Composite loss weighting schedules
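Several MIST losses are composite: they blend a region-based term (e.g., Dice) with a boundary- or topology-based term via a scalar weight \(\alpha\):
\[ \mathcal{L} = \alpha \, \mathcal{L}_{\text{region}} + (1 - \alpha) \, \mathcal{L}_{\text{boundary}} \]
Here \(\mathcal{L}_{\text{region}}\) is the region-based term and \(\mathcal{L}_{\text{boundary}}\) is the boundary- or topology-based term.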
The composite losses in MIST are: bl, hdos, gsl, cldice,
volumetric_sddl, and vessel_sddl. For all other losses (e.g., dice,
dice_ce), composite_loss_weighting has no effect and should be left as
null.
MIST supports three schedules for \(\alpha\):
| Schedule | name | Description |
|---|---|---|
| Constant | constant | Fixed \(\alpha\) throughout training. Default value: 0.5. |
| Linear | linear | Linearly interpolates from start_val to end_val over training, with an optional init_pause (epochs) before decay begins. |
| Cosine | cosine | Cosine-shaped decay from start_val to end_val, with an optional init_pause. Smoother than linear. |
When to use scheduling. With normalized DTMs (the default), the boundary
term is well-scaled from epoch 0, so a constant \(\alpha\) is usually sufficient
for bl, hdos, and gsl. Scheduling may still help for experimental losses
(volumetric_sddl, vessel_sddl) where the boundary term can be noisier early
in training.
cldice is also a candidate for scheduling. Its topology term is computed by
extracting a soft skeleton from the predicted probability map via iterative
morphological operations. Early in training, when predictions are still noisy
and poorly shaped, the extracted skeleton is unreliable — penalizing based on it
can push the network in unhelpful directions before it has learned a reasonable
approximation of the structure. Starting with higher Dice weight and gradually
increasing the skeleton term's contribution (via a linear or cosine schedule)
gives the network time to develop well-formed predictions before the topology
term becomes the dominant signal.
The table below summarises which losses support scheduling, whether they require
DTMs, and when scheduling is likely to help. All losses default to
\(\alpha = 0.5\) (equal weighting of both terms) when composite_loss_weighting
is null or set to constant with no explicit value.
| Loss | Terms blended | Requires --compute-dtms | Default \(\alpha\) | When to use scheduling |
|---|---|---|---|---|
| bl | (Dice + CE) + Boundary | Yes | 0.5 | Rarely needed: the DTM-normalized boundary term is stable from epoch 0. |
| hdos | (Dice + CE) + One-sided Hausdorff | Yes | 0.5 | Rarely needed, for the same reason as bl. |
| gsl | (Dice + CE) + Generalized surface | Yes | 0.5 | Rarely needed, for the same reason as bl. |
| cldice | (Dice + CE) + Soft skeleton (topology) | No | 0.5 | Recommended: the skeleton term is unreliable early in training; a linear or cosine schedule lets Dice + CE dominate until predictions stabilize. |
| volumetric_sddl (experimental) | (Dice + CE) + Surface Dice Dilation | No | 0.5 | May help: the boundary term can be noisier early in training. |
| vessel_sddl (experimental) | ((Dice+CE) + clDice) + Surface Dice Dilation | No | 0.5 | May help, for the same reason as volumetric_sddl; consider combining with a cldice-style schedule. Note: the same \(\alpha\) is used for both the inner clDice blend and the outer surface blend. |
composite_loss_weighting is stored in config.json as a {name, params}
object under training.loss, or set to null to disable it (equal weighting
of both terms):
"loss": {
"name": "bl",
"composite_loss_weighting": {
"name": "constant",
"params": {
"value": 0.5
}
}
}
"loss": {
"name": "gsl",
"composite_loss_weighting": {
"name": "linear",
"params": {
"init_pause": 5,
"start_val": 1.0,
"end_val": 0.0
}
}
}
"loss": {
"name": "vessel_sddl",
"composite_loss_weighting": {
"name": "cosine",
"params": {
"init_pause": 5,
"start_val": 1.0,
"end_val": 0.0
}
}
}
The --composite-loss-weighting CLI flag accepts the schedule name and
automatically writes the full default parameter set to config.json. Fine-tune
the parameters afterward by editing config.json directly.
init_pause — holding the start value before decay begins
For linear and cosine schedules, init_pause (default: 5) sets the
number of epochs during which \(\alpha\) is held at start_val before any
decay occurs. Decay then proceeds over the remaining epochs.
This is useful when the boundary or topology term needs a few epochs of
stable region-loss supervision before it starts receiving more weight. For
cldice, for example, keeping start_val: 1.0 for 5 epochs ensures the
network learns a reasonable shape before the skeleton term becomes active.
Set init_pause: 0 to begin decay immediately from epoch 0.
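The three schedules and the effect of init_pause can be sketched as a single function of the epoch. The exact interpolation MIST uses is not spelled out here, so the sketch assumes that decay starts at epoch init_pause and reaches end_val on the last epoch; the function name `composite_alpha` is hypothetical.

```python
import math


def composite_alpha(epoch, num_epochs, name="constant", value=0.5,
                    start_val=1.0, end_val=0.0, init_pause=5):
    """Weight alpha of the region term at a given (0-based) epoch."""
    if name == "constant":
        return value
    if name not in ("linear", "cosine"):
        raise ValueError(f"Unknown schedule: {name}")
    if epoch < init_pause:
        return start_val  # hold the start value during the pause

    # Fraction of the decay phase completed, in [0, 1] (assumed convention).
    decay_epochs = max(num_epochs - 1 - init_pause, 1)
    t = min((epoch - init_pause) / decay_epochs, 1.0)
    if name == "linear":
        return start_val + (end_val - start_val) * t
    return end_val + 0.5 * (start_val - end_val) * (1.0 + math.cos(math.pi * t))
```

With start_val 1.0 and end_val 0.0, the loss is pure region term during the pause and shifts toward the boundary or topology term as training progresses. Both the linear and cosine variants pass through 0.5 halfway through the decay phase, but the cosine curve changes more slowly near the two ends.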
Examples
Run the full pipeline with the generalized surface loss and a constant alpha:
mist_run_all --data /path/to/dataset.json \
--numpy /path/to/preprocessed/npy/files \
--results /path/to/results/folder \
--loss gsl \
--compute-dtms \
--composite-loss-weighting constant
Run training with the boundary loss and a linear schedule (DTMs already computed):
mist_train --numpy /path/to/preprocessed/npy/files \
--results /path/to/results/folder \
--loss bl \
--composite-loss-weighting linear
Run training with clDice and a cosine schedule (no DTMs required):
mist_train --numpy /path/to/preprocessed/npy/files \
--results /path/to/results/folder \
--loss cldice \
--composite-loss-weighting cosine
Adding a custom loss
One of MIST's core design goals is extensibility. Adding a new loss function requires no changes to the training loop, CLI, or configuration schema. Only two steps are required, plus an optional third step to opt into trainer features.
Step 1 — Implement the loss class.
Create a file mist/loss_functions/losses/my_loss.py. The class must subclass
SegmentationLoss and implement forward. Call self.preprocess() at the
start of forward to convert the raw model outputs and ground truth labels into
the form expected by loss computations.
self.preprocess(y_true, y_pred) does two things:
- Converts y_true from integer class labels of shape (B, 1, H, W, D) to a one-hot float tensor of shape (B, C, H, W, D).
- Applies softmax to y_pred along the channel dimension, converting raw logits of shape (B, C, H, W, D) into class probabilities.
Both outputs are float32 and share the same shape (B, C, H, W, D). If
exclude_background=True was passed at construction, channel 0 is dropped from
both tensors before they are returned.
# mist/loss_functions/losses/my_loss.py
from typing import Any

import torch

from mist.loss_functions.base import SegmentationLoss
from mist.loss_functions.loss_registry import register_loss


@register_loss("my_loss")
class MyLoss(SegmentationLoss):
    """A custom single-component segmentation loss."""

    def forward(
        self,
        y_true: torch.Tensor,
        y_pred: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        y_true, y_pred = self.preprocess(y_true, y_pred)
        # Implement your loss here using y_true and y_pred.
        # Both are shape (B, C, H, W, D) and float32.
        # self.spatial_dims_3d = (2, 3, 4): spatial axes for reductions.
        # self.avoid_division_by_zero: small epsilon for numerical stability.
        loss = ...
        return loss
Step 2 — Trigger registration at import time.
Add an import to mist/loss_functions/__init__.py:
# mist/loss_functions/__init__.py (add this line)
from mist.loss_functions.losses.my_loss import MyLoss
After these two steps, --loss my_loss works with mist_train and
mist_run_all, and my_loss appears in training.loss.name in config.json.
Step 3 — Opt into trainer features (optional).
The trainer uses three frozensets in
mist/training/trainers/trainer_constants.py to gate optional features. Add
your loss name to any that apply:
# mist/training/trainers/trainer_constants.py
DTM_AWARE_LOSSES: FrozenSet[str] = frozenset({"bl", "hdos", "gsl", "my_loss"})
COMPOSITE_LOSSES: FrozenSet[str] = frozenset({..., "my_loss"})
SPACING_AWARE_LOSSES: FrozenSet[str] = frozenset({..., "my_loss"})
| Frozenset | What it enables |
|---|---|
| DTM_AWARE_LOSSES | Precomputed DTMs are loaded and passed as dtm= in forward. Requires --compute-dtms at preprocessing time. |
| COMPOSITE_LOSSES | composite_loss_weighting scheduling becomes available for this loss via --composite-loss-weighting or config.json. |
| SPACING_AWARE_LOSSES | Voxel spacing is read from spatial_config.target_spacing and passed as sddl_spacing_xyz= at construction time. |
Example: composite loss with DTMs
The pattern below is how all built-in composite losses (bl, hdos, gsl)
are implemented — a region term blended with a boundary term using a
schedulable alpha:
# Same imports as in my_loss.py above.
from typing import Any

import torch

from mist.loss_functions.base import SegmentationLoss
from mist.loss_functions.loss_registry import register_loss


@register_loss("my_composite_loss")
class MyCompositeLoss(SegmentationLoss):
    """Custom composite loss blending Dice with a boundary term."""

    def forward(
        self,
        y_true: torch.Tensor,
        y_pred: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        dtm = kwargs.get("dtm")
        if dtm is None:
            raise ValueError("MyCompositeLoss requires a precomputed DTM.")
        alpha = kwargs.get("alpha", 0.5)
        y_true, y_pred = self.preprocess(y_true, y_pred)
        # Region term (e.g., Dice).
        region_loss = ...
        # Boundary term using the DTM.
        boundary_loss = ...
        return alpha * region_loss + (1.0 - alpha) * boundary_loss
Then add "my_composite_loss" to both DTM_AWARE_LOSSES and
COMPOSITE_LOSSES in trainer_constants.py. The trainer will automatically
load DTMs and pass them to forward, and --composite-loss-weighting will
control alpha across training.
Optimizers
MIST supports several optimizers commonly used in medical image segmentation:
| Optimizer | training.optimizer |
|---|---|
| Adam | adam |
| AdamW | adamw |
| SGD | sgd |
The learning rate can be adjusted with the training.learning_rate entry in the
config.json file. Weight decay is set with training.l2_penalty. These
parameters can also be set with the --learning-rate and --l2-penalty flags
in the mist_run_all or mist_train commands.
Gradient clipping is applied after every backward pass and is controlled by
training.grad_clip_norm in config.json (default: 1.0). Lowering this
value clips gradients more aggressively and can help stabilize training when
using aggressive learning rates or transformer-based architectures. It is not
exposed as a CLI flag and must be set directly in the configuration file.
Example
Run the MIST training pipeline with the AdamW optimizer and adjust the L2 penalty.
mist_train --numpy /path/to/preprocessed/npy/files \
--results /path/to/results/folder \
--optimizer adamw \
--l2-penalty 0.0001
Learning rate schedulers
Learning rate scheduling can significantly affect convergence. MIST includes the following schedulers:
| Scheduler | training.lr_scheduler |
|---|---|
| Cosine Annealing | cosine |
| Polynomial decay | polynomial |
| Constant | constant |
The scheduler can be specified with the --lr-scheduler flag or by editing
training.lr_scheduler in the config.json file.
The polynomial learning rate schedule uses a fixed decay rate of 0.9. Future
versions of MIST will make this value configurable.
Linear warmup
Any scheduler can be preceded by a linear warmup phase using --warmup-epochs.
During warmup, the learning rate increases linearly from 1% of the target LR up
to the full LR over the specified number of epochs. The main schedule then runs
for the remaining (epochs - warmup_epochs) epochs, so the full decay budget is
preserved.
Figure: learning rate (LR) versus epoch. The LR ramps up linearly during the first warmup_epochs epochs, then follows the main schedule for the rest of training.
The default is warmup_epochs: 20, which works well for the default 1000-epoch
run with AdamW and cosine decay. Warmup is most important when:
- Using transformer-based architectures such as SwinUNETR, where large gradient updates in early epochs can destabilize attention weights.
- Fine-tuning from pretrained encoder weights (--pretrained-weights), where a sudden full-LR update can damage learned features before the rest of the network has adapted.
For shorter runs or simpler CNN architectures, you can reduce or disable warmup
with --warmup-epochs 0.
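The warmup behaviour described in this section can be sketched for the cosine schedule as follows. The function name `learning_rate` is hypothetical, and two details are assumptions rather than documented behaviour: the cosine phase decays to zero, and the full LR is first reached at epoch warmup_epochs.

```python
import math


def learning_rate(epoch, base_lr=1e-3, epochs=1000, warmup_epochs=20,
                  warmup_start_factor=0.01):
    """Learning rate at a given (0-based) epoch: linear warmup, then cosine decay."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        # Ramp linearly from 1% of the target LR toward the full LR.
        progress = epoch / warmup_epochs
        factor = warmup_start_factor + (1.0 - warmup_start_factor) * progress
        return base_lr * factor

    # The main schedule runs over the remaining epochs - warmup_epochs epochs.
    main_epochs = max(epochs - warmup_epochs, 1)
    t = min((epoch - warmup_epochs) / main_epochs, 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))
```

With the defaults, the LR starts at 1e-5, reaches 1e-3 at epoch 20, and falls to half of the base LR at epoch 510, the midpoint of the cosine phase.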
Examples
Run with a polynomial learning rate schedule and a higher initial learning rate.
mist_train --numpy /path/to/preprocessed/npy/files \
--results /path/to/results/folder \
--learning-rate 0.01 \
--lr-scheduler polynomial
Run with cosine annealing and a 20-epoch linear warmup (the default).
mist_train --numpy /path/to/preprocessed/npy/files \
--results /path/to/results/folder \
--lr-scheduler cosine \
--warmup-epochs 20
Or set it directly in config.json:
"training": {
"lr_scheduler": "cosine",
"warmup_epochs": 20
}
Transfer learning
Experimental feature
Transfer learning is still under active development. You are welcome to try it, but you may encounter rough edges. If you run into problems, please open an issue on GitHub.
MIST supports initializing the encoder from a pretrained checkpoint before training. This is useful for:
- Domain adaptation — fine-tuning a model trained on one dataset to a new but related task.
- Few-shot settings — starting from a good initialization when labeled data is scarce.
- Architecture reuse — reusing encoder weights across tasks that share the same input modality and architecture.
Averaging fold weights
After a full cross-validation run, each fold produces a separate model
checkpoint. The mist_average_weights command combines them into a single
initialization checkpoint by element-wise averaging:
mist_average_weights \
--weights /path/to/results/models/fold_0.pt \
/path/to/results/models/fold_1.pt \
/path/to/results/models/fold_2.pt \
/path/to/results/models/fold_3.pt \
/path/to/results/models/fold_4.pt \
--output /path/to/pretrained_init.pt
Averaged weights generalize better than any single fold and are the recommended
input for --pretrained-weights.
Note
This averaged checkpoint is intended for transfer learning only — it is
not used for ensemble inference. For inference, MIST uses the individual
fold models in results/models/ directly.
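To make "element-wise averaging" concrete, here is a toy sketch that uses plain Python lists in place of tensors. The function name and the parameter names in the example are hypothetical; real checkpoints hold tensors and may carry extra metadata that mist_average_weights handles differently.

```python
def average_state_dicts(state_dicts):
    """Element-wise mean of checkpoints given as {name: flat list of floats}."""
    if not state_dicts:
        raise ValueError("Need at least one checkpoint to average.")
    keys = state_dicts[0].keys()
    for sd in state_dicts[1:]:
        if sd.keys() != keys:
            raise ValueError("All checkpoints must share the same parameter names.")
    averaged = {}
    for name in keys:
        # zip groups the i-th value of this parameter across all folds.
        columns = zip(*(sd[name] for sd in state_dicts))
        averaged[name] = [sum(col) / len(state_dicts) for col in columns]
    return averaged


# Two toy "folds" with hypothetical parameter names.
folds = [
    {"encoder.weight": [1.0, 2.0], "encoder.bias": [0.0]},
    {"encoder.weight": [3.0, 4.0], "encoder.bias": [1.0]},
]
print(average_state_dicts(folds))
```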
Fine-tuning from pretrained weights
Pass the pretrained checkpoint and its source config.json to mist_train:
mist_train --numpy /path/to/preprocessed/data \
--results /path/to/results \
--pretrained-weights /path/to/pretrained_init.pt \
--pretrained-config /path/to/source/config.json
MIST validates encoder compatibility between the source and target models before
loading weights. If the input channels differ between source and target, the
--input-channel-strategy flag controls how the mismatch is resolved:
| Strategy | Behaviour |
|---|---|
| average | Take the element-wise mean across all source input channels. (default) |
| first | Use only the first source input channel. |
| skip | Skip the mismatched layer and keep the random initialization. |
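The three strategies can be illustrated with a toy first-layer weight stored as nested lists shaped [out_channels][in_channels]. This is a sketch under stated assumptions: the function name is hypothetical, each weight is a single number rather than a convolution kernel, and repeating the reduced channel across all target channels is an assumption about how the mismatch is filled.

```python
def adapt_input_weights(source, target_in_channels, strategy="average"):
    """Adapt toy first-layer weights shaped [out_channels][in_channels].

    Returns None for "skip", meaning the target layer keeps its random
    initialization.
    """
    if strategy == "skip":
        return None
    adapted = []
    for out_filter in source:
        if strategy == "average":
            value = sum(out_filter) / len(out_filter)
        elif strategy == "first":
            value = out_filter[0]
        else:
            raise ValueError(f"Unknown strategy: {strategy}")
        # ASSUMPTION: the reduced channel is repeated for every target channel.
        adapted.append([value] * target_in_channels)
    return adapted


source = [[1.0, 3.0], [2.0, 6.0]]  # 2 output filters, 2 source input channels
print(adapt_input_weights(source, 3, "average"))
print(adapt_input_weights(source, 3, "first"))
```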
Note
--pretrained-config is required when --pretrained-weights is set.
MIST uses it to verify that the source and target architectures are
compatible before loading any weights.
Tip
Combine transfer learning with --warmup-epochs (e.g. 5–10 epochs) to
avoid damaging the pretrained encoder features with a large initial LR
update. See Linear warmup for details.
Docker
The MIST package is also available as a Docker image. Start by pulling the
mistmedical/mist image from DockerHub:
docker pull mistmedical/mist:latest
Use the following command to run an interactive Docker container with the MIST package:
docker run --rm -it -u $(id -u):$(id -g) \
--gpus all \
--ipc=host \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
-v /your/working/directory/:/workspace \
mistmedical/mist:latest
From there, you can run any of the commands described above inside the Docker container. Additionally, you can use the Docker entrypoint command to run any of the MIST scripts.
Custom cross validation
By default, MIST uses a random five-fold cross validation split. The folds
are stored in the file ./results/train_paths.csv, which maps each case to its
assigned fold.
You can manually edit this file to customize how folds are assigned. For
example, if you want to use six folds based on data from different
institutions, you can adjust the fold column in train_paths.csv
accordingly and then update the nfolds entry in the configuration file to 6.
This flexibility allows you to:
- Control exactly which patients go into which fold.
- Increase or decrease the number of folds beyond the default five.
- Partition folds according to metadata (e.g., scanner type, acquisition site, or institution) rather than random splits.
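As one way to partition folds by metadata, the sketch below rewrites the fold column from a case-to-site mapping. Only the fold column is documented above; the `id` and `t1` column names, the site mapping, and the function name are assumptions for illustration.

```python
import csv
import io


def assign_folds_by_site(rows, site_of_case, id_column="id"):
    """Return (rows, nfolds) with each row's "fold" set from its site."""
    sites = sorted({site_of_case[row[id_column]] for row in rows})
    fold_of_site = {site: fold for fold, site in enumerate(sites)}
    updated = [
        dict(row, fold=str(fold_of_site[site_of_case[row[id_column]]]))
        for row in rows
    ]
    return updated, len(sites)


# Toy stand-in for train_paths.csv; the "id" and "t1" columns are assumed.
text = "id,fold,t1\ncase_a,0,a.nii.gz\ncase_b,1,b.nii.gz\ncase_c,2,c.nii.gz\n"
rows = list(csv.DictReader(io.StringIO(text)))
site_of_case = {"case_a": "site_1", "case_b": "site_2", "case_c": "site_1"}

updated, nfolds = assign_folds_by_site(rows, site_of_case)
print(nfolds, [(row["id"], row["fold"]) for row in updated])
```

After writing the rows back (for example with `csv.DictWriter`), set the nfolds entry in the configuration file to the returned fold count.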
You can also adjust the train/validation/test split via the val_percent entry
in the configuration file, which takes values from 0.0 to 1.0. By default it
is 0.0, meaning the held-out fold serves as both the validation set and the
test set. For example, setting val_percent to 0.025 holds out 2.5% of the
training set for validation, while the held-out fold remains the independent
test set.
The --val-percent flag can also be used to override this setting at the
command line when running mist_run_all or mist_train. This makes it easy to
experiment with different validation splits without manually editing the
configuration file.
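The effect of val_percent on a single fold can be sketched as follows. This is an illustration only: the function name, the random selection of validation cases, and the rounding rule are assumptions, not MIST's actual splitting code.

```python
import random


def split_for_fold(fold_of_case, test_fold, val_percent=0.0, seed=42):
    """Return (train, val, test) case lists for one cross-validation fold."""
    test = sorted(c for c, f in fold_of_case.items() if f == test_fold)
    train = sorted(c for c, f in fold_of_case.items() if f != test_fold)
    if val_percent <= 0.0:
        # Default: the held-out fold doubles as the validation set.
        return train, list(test), test
    rng = random.Random(seed)
    rng.shuffle(train)
    # ASSUMPTION: round to the nearest case, keeping at least one.
    n_val = max(1, round(val_percent * len(train)))
    return sorted(train[n_val:]), sorted(train[:n_val]), test


fold_of_case = {f"case_{i:03d}": i % 5 for i in range(100)}
train, val, test = split_for_fold(fold_of_case, test_fold=0, val_percent=0.05)
print(len(train), len(val), len(test))
```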
Example
Run the MIST training pipeline with a 5% validation split in addition to the held-out fold:
mist_train --numpy /path/to/preprocessed/npy/files \
--results /path/to/results/folder \
--val-percent 0.05
Data Augmentation
Data augmentation is controlled through the training.augmentation section of
the config.json file. By default, augmentation is enabled with a standard set
of spatial and intensity transforms designed for medical imaging. The available
augmentation transforms are:
| Transform | Description | Config key |
|---|---|---|
| Flips | Random spatial flips along x, y, or z axes | flips |
| Zoom | Random zoom in/out with interpolation | zoom |
| Noise | Additive Gaussian noise to the image | noise |
| Blurring | Random Gaussian blurring | blur |
| Brightness | Random brightness scaling | brightness |
| Contrast | Random contrast adjustment | contrast |
How to customize
- To disable all augmentation, set "enabled": false.
- To toggle individual transforms, set the corresponding key to true or false.
- Augmentation is applied only during training. Test-time augmentation is controlled in the inference section of the configuration file.
Example
Disable noise and blur while keeping other augmentations active:
"augmentation": {
"enabled": true,
"transforms": {
"flips": true,
"zoom": true,
"noise": false,
"blur": false,
"brightness": true,
"contrast": true
}
}
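If you prefer to script such edits rather than change config.json by hand, a small helper along these lines may be convenient. The helper name and dotted-key convention are hypothetical; the key path follows the training.augmentation layout described above.

```python
import json


def set_config_value(config, dotted_key, value):
    """Set a nested entry such as "training.augmentation.enabled"."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        # Create intermediate sections if they are missing.
        node = node.setdefault(key, {})
    node[leaf] = value
    return config


# In practice, load this dict with json.load() from config.json.
config = {
    "training": {
        "augmentation": {
            "enabled": True,
            "transforms": {"noise": True, "blur": True},
        }
    }
}
set_config_value(config, "training.augmentation.transforms.noise", False)
set_config_value(config, "training.augmentation.transforms.blur", False)
print(json.dumps(config["training"]["augmentation"], indent=2))
```

Write the result back with `json.dump` once the edits are done.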
Foreground sampling probability
MIST uses NVIDIA DALI (Data Loading Library) for GPU-accelerated data loading and patch sampling during training. DALI pipelines run data preprocessing directly on the GPU, overlapping with model computation to reduce per-epoch overhead.
The training.dali_foreground_prob key controls the probability that a
sampled training patch contains foreground (non-zero labels). This is useful in
medical segmentation tasks where most voxels are background, ensuring the model
sees a balanced mix of positive and negative regions during training.
- A value of 0.0 means no preference: patches are sampled uniformly from the entire volume.
- A value of 1.0 means always prefer foreground: every sampled patch will contain at least some non-zero label.
- Intermediate values (e.g., 0.6) bias sampling toward foreground while still including background-only patches.
Default: 0.6
Adjusting this value can improve convergence on highly imbalanced datasets, but setting it too high may reduce the model’s ability to distinguish foreground from background.
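The idea behind foreground-biased sampling can be sketched in plain Python. This is not how DALI implements it; the function name and the strategy of centring the patch on a randomly chosen foreground voxel are assumptions used to illustrate the probability.

```python
import random


def sample_patch_center(label_volume, foreground_prob=0.6, rng=random):
    """Pick a voxel index (z, y, x) to centre a training patch on."""
    all_voxels = [
        (z, y, x)
        for z, plane in enumerate(label_volume)
        for y, row in enumerate(plane)
        for x, _ in enumerate(row)
    ]
    foreground = [
        (z, y, x) for (z, y, x) in all_voxels if label_volume[z][y][x] != 0
    ]
    # With probability foreground_prob, force the patch onto foreground.
    if foreground and rng.random() < foreground_prob:
        return rng.choice(foreground)
    # Otherwise sample uniformly from the whole volume.
    return rng.choice(all_voxels)


# Toy 2x2x2 label volume with a single foreground voxel at (0, 1, 1).
volume = [[[0, 0], [0, 1]], [[0, 0], [0, 0]]]
print(sample_patch_center(volume, foreground_prob=1.0))
```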
Multi-GPU training
MIST uses PyTorch's DistributedDataParallel (DDP) for multi-GPU data
parallelism. MIST will use all GPUs visible to the process. On HPC clusters
(SLURM, LSF, PBS), the job scheduler controls which GPUs are visible via
CUDA_VISIBLE_DEVICES — MIST respects that assignment automatically. On
shared workstations, set CUDA_VISIBLE_DEVICES yourself before running MIST
to restrict which GPUs are used.
training.hardware.master_port is a TCP port number used by PyTorch's
distributed training backend to coordinate the GPU worker processes during
startup. It defaults to 12345. If you run multiple concurrent MIST jobs on
the same machine, each job must use a different master_port value — port
conflicts will cause a job to fail with an "address already in use" error.
Change this value by editing config.json directly.
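When launching several jobs on one machine, one way to avoid port clashes is to ask the operating system for an unused port before each run. The helper below is a sketch with a hypothetical name; note that another process could still claim the port between this check and job startup.

```python
import socket


def find_free_port():
    """Ask the OS for a TCP port that is currently unused on this machine."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))  # Port 0 lets the OS choose a free port.
        return sock.getsockname()[1]


# In practice, load this dict from config.json and write it back afterwards.
config = {"training": {"hardware": {"master_port": 12345}}}
config["training"]["hardware"]["master_port"] = find_free_port()
print(config["training"]["hardware"]["master_port"])
```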
Parallelism
MIST uses two distinct forms of parallelism depending on the pipeline stage.
Patient-level I/O parallelism
Analysis, preprocessing, evaluation, and postprocessing all operate
patient-by-patient. Each stage exposes a --num-workers-* flag that controls
how many patients are processed in parallel using a Python thread pool:
| Command | Flag | Default |
|---|---|---|
| mist_analyze, mist_run_all | --num-workers-analyze | 1 |
| mist_preprocess, mist_run_all | --num-workers-preprocess | 1 |
| mist_train, mist_run_all | --num-workers-evaluate | 1 |
| mist_evaluate | --num-workers-evaluate | 1 |
| mist_postprocess | --num-workers-postprocess | 1 |
| mist_postprocess | --num-workers-evaluate | 1 |
| mist_convert_msd, mist_convert_csv | --num-workers-conversion | 1 |
mist_train exposes --num-workers-evaluate because it automatically runs
evaluation on the held-out fold predictions after each fold completes. The
flag controls the parallelism for that built-in evaluation step.
Increasing these values speeds up each stage on machines with many CPU cores. A good starting point is to match the number of available CPU cores, capped at the number of patients in the dataset.
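The starting-point rule above, together with the general shape of patient-level thread-pool processing, can be sketched as follows. Both function names are hypothetical and `process_one` stands in for whatever a stage does per patient; this is not MIST's internal code.

```python
import os
from concurrent.futures import ThreadPoolExecutor


def choose_num_workers(num_patients, cpu_count=None):
    """Match the CPU core count, capped at the number of patients."""
    cores = cpu_count or os.cpu_count() or 1
    return max(1, min(cores, num_patients))


def process_patients(patient_ids, process_one, num_workers):
    """Apply process_one to every patient using a thread pool."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # map preserves input order even when tasks finish out of order.
        return list(pool.map(process_one, patient_ids))


patients = ["patient_001", "patient_002", "patient_003"]
workers = choose_num_workers(len(patients))
print(process_patients(patients, str.upper, workers))
```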
Note
Not all stages have the same memory cost per worker. Analysis only reads image headers and samples voxel statistics, so it is lightweight and scales well to many workers. Preprocessing loads each full 3D volume into memory, resamples it, normalizes it, and writes NumPy arrays — each additional worker holds at least one complete image in RAM, more if DTMs are enabled. Evaluation and postprocessing load segmentation masks, which are smaller than raw images but still non-trivial for large volumes. On memory-constrained systems it may be necessary to use a higher worker count for analysis than for preprocessing.
Training data-loading parallelism (DALI)
Training uses a different mechanism. MIST uses NVIDIA DALI to build a GPU-accelerated data pipeline that reads, augments, and assembles batches while the GPU is busy computing gradients. DALI's parallelism is controlled by CPU threads within the pipeline, not by Python workers.
The number of DALI CPU threads is set via training.hardware.num_cpu_workers
in config.json (default: 8). There is no CLI flag for this — it is a
hardware-tuning setting that belongs in the configuration file alongside the
rest of the training hardware parameters. Edit it directly if your machine
has significantly more or fewer CPU cores than the default assumes:
"hardware": {
"num_cpu_workers": 16
}
Note
num_cpu_workers is not the same as the --num-workers-* flags. It does
not control how many patients are loaded in parallel — it controls how many
CPU threads DALI uses internally to prepare the next batch while the current
batch is on the GPU.
Evaluation metrics
The evaluation section of config.json controls which metrics are computed
for each segmentation class during training and when using mist_evaluate
standalone. It is generated automatically by the analysis step but can be
freely edited afterward.
Each key in the section is a class name. Each class specifies its constituent labels and the metrics to compute:
"class_name": {
"labels": [list of label integers],
"metrics": {
"metric_name": { "param": value },
"another_metric": {}
}
}
The value for each metric entry is a dict of keyword arguments passed directly
to the metric function. For metrics with no parameters, use an empty dict {}.
This structure means each class can use a completely different set of metrics
and parameters — there is no requirement to evaluate every class the same way.
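To show how this structure might be consumed, the sketch below merges a class's labels into one foreground set and passes each metric's dict as keyword arguments. The `METRICS` registry, the function names, the voxel-index representation, and the convention that two empty masks score 1.0 are all assumptions for illustration.

```python
def dice(truth, pred):
    """Dice coefficient between two sets of foreground voxel indices."""
    if not truth and not pred:
        return 1.0  # ASSUMPTION: two empty masks count as a perfect match.
    return 2.0 * len(truth & pred) / (len(truth) + len(pred))


# Hypothetical registry mapping metric names to functions.
METRICS = {"dice": dice}


def evaluate_class(class_config, truth_labels, pred_labels):
    """Compute every configured metric for one class.

    truth_labels and pred_labels map voxel index -> integer label.
    """
    wanted = set(class_config["labels"])
    truth = {v for v, lab in truth_labels.items() if lab in wanted}
    pred = {v for v, lab in pred_labels.items() if lab in wanted}
    return {
        # Each metric's dict is passed through as keyword arguments.
        name: METRICS[name](truth, pred, **params)
        for name, params in class_config["metrics"].items()
    }


wt = {"labels": [1, 2, 4], "metrics": {"dice": {}}}
truth_labels = {0: 1, 1: 2, 2: 0, 3: 4}
pred_labels = {0: 1, 1: 0, 2: 0, 3: 4}
print(evaluate_class(wt, truth_labels, pred_labels))
```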
Available metrics
| Metric | name | Parameters |
|---|---|---|
| Dice coefficient | dice | none |
| 95th percentile Hausdorff distance | haus95 | none |
| Average surface distance | avg_surf | none |
| Surface Dice | surf_dice | tolerance (mm, default: 1.0) |
| BraTS-style lesion-wise Dice | lesion_wise_dice | see lesion-wise parameters |
| BraTS-style lesion-wise HD95 | lesion_wise_haus95 | see lesion-wise parameters |
| BraTS-style lesion-wise surface Dice | lesion_wise_surf_dice | see lesion-wise parameters |
Example: different metrics per class
The following configuration computes only Dice for the whole tumor class, and Dice + surface Dice for the enhancing tumor class:
"evaluation": {
"wt": {
"labels": [1, 2, 4],
"metrics": {
"dice": {}
}
},
"et": {
"labels": [4],
"metrics": {
"dice": {},
"surf_dice": {"tolerance": 1.0}
}
}
}
Example: per-class surface Dice tolerances
Surface Dice tolerance is clinically meaningful — a tight tolerance is appropriate for fine structures like vessels but too strict for larger structures. Each class can use a different tolerance:
"evaluation": {
"wt": {
"labels": [1, 2, 4],
"metrics": {
"dice": {},
"surf_dice": {"tolerance": 3.0}
}
},
"tc": {
"labels": [1, 4],
"metrics": {
"dice": {},
"surf_dice": {"tolerance": 2.0}
}
},
"et": {
"labels": [4],
"metrics": {
"dice": {},
"surf_dice": {"tolerance": 1.0}
}
}
}