
Cellpose-SAM: superhuman generalization for cellular segmentation

Marius Pachitariu, Michael Rariden, Carsen Stringer

paper | code

This notebook shows how to process example 2D and 3D images with the Cellpose package on Google Colab, using the GPU.

Make sure GPU access is enabled: go to Runtime -> Change runtime type -> Hardware accelerator and select GPU.


Install Cellpose-SAM

!pip install git+https://www.github.com/mouseland/cellpose.git

Check for GPU access and instantiate the model. The model weights are downloaded the first time this runs.

import numpy as np
from cellpose import models, core, io, plot
from pathlib import Path
from tqdm import trange
import matplotlib.pyplot as plt

io.logger_setup()  # run this to get printing of progress

# Check whether the Colab instance has GPU access
if not core.use_gpu():
    raise RuntimeError("No GPU access, change your runtime")

model = models.CellposeModel(gpu=True)
Welcome to CellposeSAM, cellpose v
cellpose version: 	4.0.5.dev23+g15eb3c6 
platform:       	darwin 
python version: 	3.13.0 
torch version:  	2.7.1! The neural network component of
CPSAM is much larger than in previous versions and CPU excution is slow. 
We encourage users to use GPU/MPS if available. 
2025-06-17 22:56:03,372 [INFO] WRITING LOG OUTPUT TO /Users/ranit/.cellpose/run.log
2025-06-17 22:56:03,372 [INFO] 
cellpose version: 	4.0.5.dev23+g15eb3c6 
platform:       	darwin 
python version: 	3.13.0 
torch version:  	2.7.1
2025-06-17 22:56:03,399 [INFO] ** TORCH MPS version installed and working. **
2025-06-17 22:56:03,400 [INFO] ** TORCH MPS version installed and working. **
2025-06-17 22:56:03,400 [INFO] >>>> using GPU (MPS)
2025-06-17 22:56:04,954 [INFO] >>>> loading model /Users/ranit/.cellpose/models/cpsam

Download example images

import numpy as np
import matplotlib.pyplot as plt
from cellpose import utils, io

# download example 2D images from website
url = "http://www.cellpose.org/static/data/imgs_cyto3.npz"
filename = "imgs_cyto3.npz"
utils.download_url_to_file(url, filename)

# download 3D tiff
url = "http://www.cellpose.org/static/data/rgb_3D.tif"
utils.download_url_to_file(url, "rgb_3D.tif")

# the .npz file stores a single pickled dict under "arr_0"; .item() unwraps it
dat = np.load(filename, allow_pickle=True)["arr_0"].item()

imgs = dat["imgs"]
masks_true = dat["masks_true"]

plt.figure(figsize=(8,3))
for i, iex in enumerate([9, 16, 21]):
    img = imgs[iex].squeeze()
    plt.subplot(1,3,1+i)
    plt.imshow(img[0], cmap="gray", vmin=0, vmax=1)
    plt.axis('off')
plt.tight_layout()
plt.show()
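The example file keeps all images and ground-truth masks in one pickled Python dict. That is why the loading code indexes `"arr_0"` and then calls `.item()`. The sketch below reproduces the round trip with plain NumPy, using small synthetic arrays whose shapes are made up for illustration. It writes to an in-memory buffer and uses `toy_` names so it does not overwrite the notebook's `dat` or shadow the Cellpose `io` module.

```python
from io import BytesIO
import numpy as np

# synthetic stand-in for the example data: a dict holding lists of arrays
toy_data = {
    "imgs": [np.zeros((2, 8, 8), dtype=np.float32) for _ in range(3)],
    "masks_true": [np.zeros((8, 8), dtype=np.uint16) for _ in range(3)],
}

buf = BytesIO()
# a positional argument is stored under the default key "arr_0";
# NumPy wraps the dict in a 0-d object array, which is saved via pickle
np.savez(buf, toy_data)
buf.seek(0)

loaded = np.load(buf, allow_pickle=True)["arr_0"]
print(loaded.shape, loaded.dtype)  # a 0-d array with dtype object
toy_dat = loaded.item()            # unwrap it back into the original dict
print(sorted(toy_dat.keys()))
```

Without `allow_pickle=True`, `np.load` refuses to load the object array. Only enable it for files you trust.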

Run Cellpose-SAM

masks_pred, flows, styles = model.eval(imgs, niter=1000)  # niter=1000: more iterations than the default, which helps for bacteria
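Cellpose builds masks from its predicted flows by letting every pixel follow the flow field for `niter` steps. Pixels that end up at the same point are then grouped into one object. Long, thin cells such as bacteria give pixels a longer path to travel, which is presumably why the cell above raises `niter`.

Below is a toy NumPy sketch of that idea. It assumes a synthetic flow field that points each pixel at its nearest hypothetical cell centre. Cellpose's own dynamics interpolate the network's flows and cluster the end points differently, so this only illustrates why too few iterations leave pixels short of their destination.

```python
import numpy as np

def toy_flows(shape, centers):
    """Hypothetical flow field: unit vectors pointing from each pixel to its nearest center."""
    pts = np.stack(np.mgrid[0:shape[0], 0:shape[1]], axis=-1).astype(float)  # (H, W, 2)
    c = np.asarray(centers, dtype=float)                                     # (K, 2)
    diff = c[None, None] - pts[:, :, None]                                   # (H, W, K, 2)
    owner = np.linalg.norm(diff, axis=-1).argmin(axis=-1)                    # nearest center per pixel
    vec = np.take_along_axis(diff, owner[..., None, None], axis=2)[:, :, 0]  # (H, W, 2)
    norm = np.linalg.norm(vec, axis=-1, keepdims=True)
    # unit length everywhere except at the center pixels, where the flow is zero
    return np.divide(vec, norm, out=np.zeros_like(vec), where=norm > 0), owner

def follow_flows(flow, niter):
    """niter Euler steps: every pixel moves by the flow sampled at its nearest grid point."""
    H, W, _ = flow.shape
    p = np.stack(np.mgrid[0:H, 0:W], axis=-1).astype(float).reshape(-1, 2)
    for _ in range(niter):
        iy = np.clip(np.rint(p[:, 0]), 0, H - 1).astype(int)
        ix = np.clip(np.rint(p[:, 1]), 0, W - 1).astype(int)
        p = p + flow[iy, ix]
    return p.reshape(H, W, 2)

centers = np.array([[5.0, 5.0], [5.0, 15.0]])  # two hypothetical cells
flow, owner = toy_flows((11, 21), centers)

frac, agrees = {}, {}
for niter in (2, 50):
    final = follow_flows(flow, niter)
    dist = np.linalg.norm(final[:, :, None] - centers[None, None], axis=-1)  # (H, W, K)
    frac[niter] = float((dist.min(axis=-1) < 1.0).mean())         # share of converged pixels
    agrees[niter] = bool((dist.argmin(axis=-1) == owner).all())  # ended at the right cell
    print(f"niter={niter}: {frac[niter]:.0%} of pixels converged")
```

With only 2 steps, pixels far from a centre have not arrived yet. With 50 steps, every pixel reaches its own cell's centre, so grouping by end point recovers the two objects.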

Plot results: ground-truth outlines in purple, predicted outlines in dashed yellow.

from cellpose import transforms, plot, utils

# one title per dataset; each dataset contributes 3 example images
titles = [
    "Cellpose", "Nuclei", "Tissuenet", "Livecell", "YeaZ",
    "Omnipose\nphase-contrast", "Omnipose\nfluorescent",
    "DeepBacs"
]

plt.figure(figsize=(12, 6))
ly = 400  # show at most a 400 x 400 pixel crop of each image
for iex in range(len(imgs)):
    img = imgs[iex].squeeze().copy()
    img = np.clip(transforms.normalize_img(img, axis=0), 0, 1)  # normalize images across channel axis
    # column = dataset, row = example within that dataset
    ax = plt.subplot(3, 8, (iex % 3) * 8 + (iex // 3) + 1)
    if img[1].sum() == 0:
        # second channel is empty: show the first channel in grayscale
        img = img[0]
        ax.imshow(img, cmap="gray")
    else:
        # make RGB from 2-channel image (R = 0, G = channel 0, B = channel 1)
        img = np.concatenate((np.zeros_like(img)[:1], img), axis=0).transpose(1, 2, 0)
        ax.imshow(img)
    # crop; y-limits are given as (bottom, top) so the image is not flipped
    ax.set_ylim([min(ly, img.shape[0]), 0])
    ax.set_xlim([0, min(ly, img.shape[1])])

    # GROUND-TRUTH = PURPLE
    # PREDICTED = YELLOW
    outlines_gt = utils.outlines_list(masks_true[iex])
    outlines_pred = utils.outlines_list(masks_pred[iex])
    for o in outlines_gt:
        ax.plot(o[:, 0], o[:, 1], color=[0.7, 0.4, 1], lw=0.5)
    for o in outlines_pred:
        ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=0.75, ls="--")
    ax.axis("off")

    if iex % 3 == 0:
        ax.set_title(titles[iex // 3])

plt.tight_layout()
plt.show()
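The plot compares the outlines by eye. A common way to put a number on the agreement is to match ground-truth and predicted objects by intersection over union (IoU), then report tp / (tp + fp + fn) at an IoU threshold. Cellpose ships its own `metrics` module for this kind of scoring.

The sketch below is an independent NumPy illustration, not that module's implementation. The function names are hypothetical, and inputs are label images in which 0 is background.

```python
import numpy as np

def iou_matrix(masks_true, masks_pred):
    """IoU between every true label (rows) and predicted label (cols); index 0 = background."""
    t = masks_true.ravel().astype(np.int64)
    p = masks_pred.ravel().astype(np.int64)
    overlap = np.zeros((t.max() + 1, p.max() + 1), dtype=np.int64)
    np.add.at(overlap, (t, p), 1)  # pixel count for every (true, pred) label pair
    union = overlap.sum(axis=1, keepdims=True) + overlap.sum(axis=0, keepdims=True) - overlap
    return np.divide(overlap, union, out=np.zeros(overlap.shape), where=union > 0)

def match_scores(masks_true, masks_pred, threshold=0.5):
    """Return (tp, fp, fn, tp / (tp + fp + fn)) at an IoU threshold >= 0.5."""
    # above IoU 0.5 an object can overlap at most one partner, so matches are one-to-one
    assert threshold >= 0.5, "this simple matching needs threshold >= 0.5"
    iou = iou_matrix(masks_true, masks_pred)[1:, 1:]  # drop background row and column
    n_true = np.unique(masks_true[masks_true > 0]).size
    n_pred = np.unique(masks_pred[masks_pred > 0]).size
    tp = int((iou > threshold).sum())
    fp, fn = n_pred - tp, n_true - tp
    return tp, fp, fn, tp / max(tp + fp + fn, 1)

gt = np.array([[1, 1, 0, 0],
               [1, 1, 0, 0],
               [0, 0, 2, 2],
               [0, 0, 2, 2]])
pred = np.array([[1, 1, 0, 0],
                 [1, 1, 0, 0],
                 [0, 0, 0, 0],
                 [0, 0, 0, 3]])
print(match_scores(gt, pred))  # (1, 1, 1, 0.3333333333333333)
```

With the notebook's variables, the same call would be `match_scores(masks_true[iex], masks_pred[iex])` for each example image.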

Run Cellpose-SAM in 3D

There are two ways to run Cellpose in 3D. The cells below show both; choose whichever works best for your data.

First way: compute flows on 2D slices, combine them into 3D flows, and create masks from those 3D flows.

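img_3D = io.imread("rgb_3D.tif")  # stored as (Z, C, Y, X), hence z_axis=0 and channel_axis=1

# 1. compute flows from 2D slices and combine them into 3D flows to create masks
# flow3D_smooth=1: smooth the 3D flows with a Gaussian (stddev 1) before computing masks
masks, flows, _ = model.eval(img_3D, z_axis=0, channel_axis=1,
                             batch_size=32,
                             do_3D=True, flow3D_smooth=1)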
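With `do_3D=True`, Cellpose runs its 2D network on slices in all three orthogonal orientations (YX, ZY and ZX). Each of the three flow components is therefore predicted twice, and the two estimates are combined.

The toy sketch below mimics only that bookkeeping. `np.gradient` on each 2D slice stands in for the network, which is purely an assumption for illustration. The function names are hypothetical.

```python
import numpy as np

def slice_flows(stack):
    """Stand-in for the 2D network: in-plane gradients of each slice in an (N, A, B) stack."""
    dA, dB = np.gradient(stack, axis=(1, 2))
    return dA, dB

def flows_from_orthogonal_slices(vol):
    """vol has axes (Z, Y, X); returns (dZ, dY, dX), each averaged over two slice orientations."""
    # YX slices (one per z) give dY and dX
    dy_yx, dx_yx = slice_flows(vol)
    # ZX slices (one per y): view the volume as (Y, Z, X), then swap back to (Z, Y, X)
    dz_zx, dx_zx = (np.transpose(a, (1, 0, 2)) for a in slice_flows(np.transpose(vol, (1, 0, 2))))
    # ZY slices (one per x): view the volume as (X, Z, Y), then move axes back to (Z, Y, X)
    dz_zy, dy_zy = (np.transpose(a, (1, 2, 0)) for a in slice_flows(np.transpose(vol, (2, 0, 1))))
    return (dz_zx + dz_zy) / 2, (dy_yx + dy_zy) / 2, (dx_yx + dx_zx) / 2

rng = np.random.default_rng(0)
vol = rng.random((6, 7, 8))
dZ, dY, dX = flows_from_orthogonal_slices(vol)
# for this stand-in both estimates agree, so the result equals the full 3D gradient
print(all(np.allclose(a, b) for a, b in zip((dZ, dY, dX), np.gradient(vol))))
```

With a real network the two estimates of each component differ somewhat. Combining them is what lets 2D predictions produce a consistent 3D flow field.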

Second way: compute masks on each 2D slice independently, then stitch them into 3D objects based on mask overlap between neighbouring slices.

Note that stitching (with stitch_threshold > 0) can also be used to track cells over time.

# 2. compute masks in 2D slices and stitch them in 3D based on mask overlap
# stitch_threshold=0.5: masks in neighbouring slices are linked when their overlap (IoU) exceeds 0.5
print("running cellpose 2D + stitching masks")
masks_stitched, flows_stitched, _ = model.eval(img_3D, z_axis=0, channel_axis=1,
                                               batch_size=32,
                                               do_3D=False, stitch_threshold=0.5)
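Stitching links objects across planes by their overlap. The sketch below is a minimal greedy version written from scratch for illustration; the function names are hypothetical, and Cellpose's own stitching code differs in detail. Each object in plane z takes the label of the plane z-1 object it overlaps with IoU above `stitch_threshold`; otherwise it receives a fresh label. If the first axis is read as time instead of depth, the same loop performs the simple tracking mentioned above.

```python
import numpy as np

def iou_matrix(a, b):
    """IoU between every label in a (rows) and in b (cols); index 0 = background."""
    a = a.ravel().astype(np.int64)
    b = b.ravel().astype(np.int64)
    overlap = np.zeros((a.max() + 1, b.max() + 1), dtype=np.int64)
    np.add.at(overlap, (a, b), 1)
    union = overlap.sum(axis=1, keepdims=True) + overlap.sum(axis=0, keepdims=True) - overlap
    return np.divide(overlap, union, out=np.zeros(overlap.shape), where=union > 0)

def stitch_planes(masks, stitch_threshold=0.5):
    """masks: (Z, Y, X) labels segmented independently per plane -> consistently labelled 3D objects."""
    out = np.zeros_like(masks)
    out[0] = masks[0]
    next_label = int(masks[0].max()) + 1
    for z in range(1, masks.shape[0]):
        iou = iou_matrix(out[z - 1], masks[z])
        iou[0, :] = 0  # never inherit the background label
        for lbl in np.unique(masks[z]):
            if lbl == 0:
                continue
            best = int(iou[:, lbl].argmax())
            if iou[best, lbl] > stitch_threshold:
                new = best  # same object as in the previous plane
            else:
                new, next_label = next_label, next_label + 1  # a new object starts here
            out[z][masks[z] == lbl] = new
    return out

toy_stack = np.zeros((2, 4, 4), dtype=int)
toy_stack[0, 0:2, 0:2] = 1  # an object in plane 0
toy_stack[1, 0:2, 0:2] = 5  # the same object, labelled 5 in plane 1
toy_stack[1, 3, 3] = 1      # a new object in plane 1
stitched = stitch_planes(toy_stack)
print(stitched[1])          # the square keeps label 1, the new object gets label 2
```

For thresholds of 0.5 or more, two disjoint objects cannot both exceed the threshold against the same previous object. That keeps this greedy matching one-to-one.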

Results from the 3D flows => masks computation

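# DISPLAY RESULTS 3D flows => masks
plt.figure(figsize=(15, 3))
for i, iplane in enumerate(range(0, 75, 10)):  # planes 0, 10, ..., 70
    # reorder the two channels and convert the plane to RGB for display
    img0 = plot.image_to_rgb(img_3D[iplane, [1, 0]].copy(), channels=[2, 3])
    plt.subplot(1, 8, i + 1)
    # paint mask outlines red on top of the image
    outlines = utils.masks_to_outlines(masks[iplane])
    outY, outX = np.nonzero(outlines)  # np.nonzero returns (row, column) indices
    imgout = img0.copy()
    imgout[outY, outX] = np.array([255, 75, 75])
    plt.imshow(imgout)
    plt.title("iplane = %d" % iplane)
    plt.axis("off")
plt.tight_layout()
plt.show()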
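`utils.masks_to_outlines` returns a boolean image that marks mask boundaries, and the cell above paints those pixels red. One simple way to define such outlines is shown below as an illustrative NumPy sketch, not Cellpose's implementation; `label_boundaries` is a hypothetical name. It marks every labelled pixel that has a 4-connected neighbour with a different label. The `toy_` variables avoid overwriting the notebook's `masks` and `outlines`.

```python
import numpy as np

def label_boundaries(masks):
    """True where a labelled pixel touches (4-connectivity) a pixel with a different label."""
    padded = np.pad(masks, 1, mode="edge")  # edge padding: the image border is not a boundary
    center = padded[1:-1, 1:-1]
    differs = np.zeros(masks.shape, dtype=bool)
    for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        neighbour = padded[1 + dy:padded.shape[0] - 1 + dy, 1 + dx:padded.shape[1] - 1 + dx]
        differs |= neighbour != center
    return differs & (masks > 0)

toy_masks = np.zeros((5, 5), dtype=int)
toy_masks[1:4, 1:4] = 1                    # one 3x3 object
toy_outlines = label_boundaries(toy_masks)
toy_rgb = np.zeros((5, 5, 3), dtype=np.uint8)
rows, cols = np.nonzero(toy_outlines)      # (row, column) indices of outline pixels
toy_rgb[rows, cols] = [255, 75, 75]        # paint the outline red, as in the display cells
print(toy_outlines.astype(int))
```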

Results from stitching

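# DISPLAY RESULTS stitching
plt.figure(figsize=(15, 3))
for i, iplane in enumerate(range(0, 75, 10)):  # planes 0, 10, ..., 70
    # reorder the two channels and convert the plane to RGB for display
    img0 = plot.image_to_rgb(img_3D[iplane, [1, 0]].copy(), channels=[2, 3])
    plt.subplot(1, 8, i + 1)
    # paint stitched mask outlines red on top of the image
    outlines = utils.masks_to_outlines(masks_stitched[iplane])
    outY, outX = np.nonzero(outlines)  # np.nonzero returns (row, column) indices
    imgout = img0.copy()
    imgout[outY, outX] = np.array([255, 75, 75])
    plt.imshow(imgout)
    plt.title("iplane = %d" % iplane)
    plt.axis("off")
plt.tight_layout()
plt.show()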
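Both display cells hard-code planes 0, 10, ..., 70, which assumes the stack has at least 71 planes. A small hypothetical helper (not part of Cellpose) picks evenly spaced planes for any stack depth instead. It could be used in either display cell as `enumerate(pick_planes(img_3D.shape[0]))`, keeping `plt.subplot(1, 8, i + 1)`.

```python
import numpy as np

def pick_planes(n_planes, n_show=8):
    """Return up to n_show evenly spaced, unique plane indices in [0, n_planes - 1]."""
    idx = np.linspace(0, n_planes - 1, num=min(n_show, n_planes))
    return np.unique(np.rint(idx).astype(int))

print(pick_planes(75))  # eight planes spanning the whole stack, from 0 to 74
print(pick_planes(3))   # short stacks simply show every plane
```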