Cellpose-SAM: superhuman generalization for cellular segmentation
Marius Pachitariu, Michael Rariden, Carsen Stringer
This notebook shows how to segment example 2D and 3D images with the Cellpose package on Google Colab using a GPU.
Make sure GPU access is enabled: go to Runtime -> Change runtime type -> Hardware accelerator and select GPU
Install Cellpose-SAM
!pip install git+https://www.github.com/mouseland/cellpose.git
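To confirm which Cellpose version the install actually provided (Cellpose-SAM ships as the 4.x series), you can read the package metadata from Python. This is a minimal sketch using only the standard library's `importlib.metadata`; the helper name `installed_version` is hypothetical, not part of Cellpose.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# After the pip install above, this should report a 4.x version string.
print("cellpose:", installed_version("cellpose"))
```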
Check GPU access and instantiate the model (the model weights are downloaded on first use).
import numpy as np
from cellpose import models, core, io, plot
from pathlib import Path
from tqdm import trange
import matplotlib.pyplot as plt
io.logger_setup() # run this to get printing of progress
# Check that the notebook instance has GPU access
if not core.use_gpu():
    raise RuntimeError("No GPU access, change your runtime")
model = models.CellposeModel(gpu=True)
Welcome to CellposeSAM, cellpose v
cellpose version: 4.0.5.dev23+g15eb3c6
platform: darwin
python version: 3.13.0
torch version: 2.7.1! The neural network component of
CPSAM is much larger than in previous versions and CPU excution is slow.
We encourage users to use GPU/MPS if available.
2025-06-17 22:56:03,372 [INFO] WRITING LOG OUTPUT TO /Users/ranit/.cellpose/run.log
2025-06-17 22:56:03,372 [INFO]
cellpose version: 4.0.5.dev23+g15eb3c6
platform: darwin
python version: 3.13.0
torch version: 2.7.1
2025-06-17 22:56:03,399 [INFO] ** TORCH MPS version installed and working. **
2025-06-17 22:56:03,400 [INFO] ** TORCH MPS version installed and working. **
2025-06-17 22:56:03,400 [INFO] >>>> using GPU (MPS)
2025-06-17 22:56:04,954 [INFO] >>>> loading model /Users/ranit/.cellpose/models/cpsam
Download example images
import numpy as np
import matplotlib.pyplot as plt
from cellpose import utils, io
# download example 2D images from website
url = "http://www.cellpose.org/static/data/imgs_cyto3.npz"
filename = "imgs_cyto3.npz"
utils.download_url_to_file(url, filename)
# download 3D tiff
url = "http://www.cellpose.org/static/data/rgb_3D.tif"
utils.download_url_to_file(url, "rgb_3D.tif")
dat = np.load(filename, allow_pickle=True)["arr_0"].item()
imgs = dat["imgs"]
masks_true = dat["masks_true"]
plt.figure(figsize=(8,3))
for i, iex in enumerate([9, 16, 21]):
    img = imgs[iex].squeeze()
    plt.subplot(1,3,1+i)
    plt.imshow(img[0], cmap="gray", vmin=0, vmax=1)
    plt.axis('off')
plt.tight_layout()
plt.show()
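The example file stores a Python dictionary rather than plain arrays. That is why loading it needs `allow_pickle=True`, the `"arr_0"` key, and `.item()`. The sketch below reproduces that round trip on synthetic data (the shapes and values are made up, and the names are chosen so that they do not overwrite the notebook's `imgs` or `masks_true`). When `np.savez` is given a non-array positional argument, it saves it as a 0-d object array under `arr_0`, and `.item()` unwraps it again.

```python
import io

import numpy as np

# Synthetic stand-in for imgs_cyto3.npz: a dict holding lists of arrays.
example = {
    "imgs": [np.random.rand(2, 64, 64).astype(np.float32) for _ in range(3)],
    "masks_true": [np.zeros((64, 64), dtype=np.uint16) for _ in range(3)],
}

buf = io.BytesIO()
# The dict becomes a 0-d object array stored under the default key "arr_0".
np.savez(buf, example)
buf.seek(0)

# Object arrays are pickled, so allow_pickle=True is required to read them back.
with np.load(buf, allow_pickle=True) as npz:
    loaded = npz["arr_0"].item()  # unwrap the 0-d object array into the dict

print(len(loaded["imgs"]), loaded["imgs"][0].shape, loaded["masks_true"][0].dtype)
```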
Run Cellpose-SAM
masks_pred, flows, styles = model.eval(imgs, niter=1000) # using more iterations for bacteria
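The `niter` comment relates to how masks are recovered from the predicted flows. Each pixel is moved along the flow field for a number of steps, and pixels that end up at the same point are grouped into one cell. In long, thin cells such as bacteria, pixels have farther to travel, so more steps are needed before they converge. The sketch below illustrates this on a synthetic flow field. It is a simplified illustration, not Cellpose's own dynamics code: the helper `follow_flows` is hypothetical, and it reads the flow at the nearest pixel.

```python
import numpy as np


def follow_flows(flow: np.ndarray, niter: int) -> np.ndarray:
    """Move every pixel along a 2D flow field for `niter` Euler steps.

    flow has shape (2, Ly, Lx) holding (dy, dx) at each pixel. Returns the
    final (y, x) position of every pixel, shape (2, Ly, Lx). Simplified:
    the flow is read at the pixel nearest to the current position.
    """
    Ly, Lx = flow.shape[1:]
    pos = np.stack(np.meshgrid(np.arange(Ly), np.arange(Lx), indexing="ij")).astype(float)
    for _ in range(niter):
        iy = np.clip(np.rint(pos[0]), 0, Ly - 1).astype(int)
        ix = np.clip(np.rint(pos[1]), 0, Lx - 1).astype(int)
        pos[0] = np.clip(pos[0] + flow[0, iy, ix], 0, Ly - 1)
        pos[1] = np.clip(pos[1] + flow[1, iy, ix], 0, Lx - 1)
    return pos


# Synthetic "cell": unit-length flows pointing at the center pixel (20, 20).
size, center = 41, 20
yy, xx = np.meshgrid(np.arange(size), np.arange(size), indexing="ij")
vec = np.stack([center - yy, center - xx]).astype(float)
demo_flow = vec / np.maximum(np.linalg.norm(vec, axis=0), 1e-6)


def max_dist_from_center(pos: np.ndarray) -> float:
    return float(np.hypot(pos[0] - center, pos[1] - center).max())


# Too few steps: corner pixels are still far from the center.
print("niter=5:  ", max_dist_from_center(follow_flows(demo_flow, niter=5)))
# Enough steps: every pixel ends within about one pixel of the center (one "cell").
print("niter=200:", max_dist_from_center(follow_flows(demo_flow, niter=200)))
```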
Plot results
from cellpose import transforms, plot
titles = [
"Cellpose", "Nuclei", "Tissuenet", "Livecell", "YeaZ",
"Omnipose\nphase-contrast", "Omnipose\nfluorescent",
"DeepBacs"
]
plt.figure(figsize=(12,6))
ly = 400  # show at most a ly x ly pixel crop of each image
for iex in range(len(imgs)):
    img = imgs[iex].squeeze().copy()
    img = np.clip(transforms.normalize_img(img, axis=0), 0, 1) # normalize each channel (axis 0 = channels)
    ax = plt.subplot(3, 8, (iex%3)*8 + (iex//3) +1)
    if img[1].sum()==0:
        img = img[0]
        ax.imshow(img, cmap="gray")
    else:
        # make RGB from 2 channel image
        img = np.concatenate((np.zeros_like(img)[:1], img), axis=0).transpose(1,2,0)
        ax.imshow(img)
    ax.set_ylim([0, min(ly, img.shape[0])])
    ax.set_xlim([0, min(ly, img.shape[1])])
    # GROUND-TRUTH = PURPLE
    # PREDICTED = YELLOW
    outlines_gt = utils.outlines_list(masks_true[iex])
    outlines_pred = utils.outlines_list(masks_pred[iex])
    for o in outlines_gt:
        plt.plot(o[:,0], o[:,1], color=[0.7,0.4,1], lw=0.5)
    for o in outlines_pred:
        plt.plot(o[:,0], o[:,1], color=[1,1,0.3], lw=0.75, ls="--")
    plt.axis('off')
    if iex%3 == 0:
        ax.set_title(titles[iex//3])
plt.tight_layout()
plt.show()
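The plotting loop calls `transforms.normalize_img` to rescale each channel before display. The sketch below approximates that step, assuming the default behavior maps each channel's 1st percentile to 0 and its 99th percentile to 1. The function `normalize_percentile` is hypothetical and is not Cellpose's code. The sketch also repeats the loop's two-channel-to-RGB trick: a zero red channel is prepended, so the two channels appear as green and blue.

```python
import numpy as np


def normalize_percentile(img: np.ndarray, lower: float = 1.0, upper: float = 99.0) -> np.ndarray:
    """Rescale each channel of a (C, Ly, Lx) image so that its `lower`
    percentile maps to 0 and its `upper` percentile maps to 1."""
    out = np.zeros(img.shape, dtype=np.float32)
    for c in range(img.shape[0]):
        lo, hi = np.percentile(img[c], [lower, upper])
        if hi > lo:  # leave empty or constant channels at zero
            out[c] = (img[c] - lo) / (hi - lo)
    return out


rng = np.random.default_rng(0)
# Synthetic 2-channel image whose channels have very different intensity ranges.
two_chan = np.stack([rng.uniform(100, 200, (32, 32)), rng.uniform(0, 5, (32, 32))])
scaled = np.clip(normalize_percentile(two_chan), 0, 1)

# Prepend an empty red channel, then move channels last: R=0, G=chan 0, B=chan 1.
rgb = np.concatenate((np.zeros_like(scaled)[:1], scaled), axis=0).transpose(1, 2, 0)
print(rgb.shape, rgb.min(), rgb.max())
```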
Run Cellpose-SAM in 3D
There are two ways to run Cellpose in 3D. This section shows both, so choose whichever works best for your data.
First way: compute flows on 2D slices, combine them into 3D flows, and create masks from those 3D flows.
img_3D = io.imread("rgb_3D.tif")
# 1. computes flows from 2D slices and combines into 3D flows to create masks
masks, flows, _ = model.eval(img_3D, z_axis=0, channel_axis=1,
                             batch_size=32,
                             do_3D=True, flow3D_smooth=1)
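`z_axis=0` and `channel_axis=1` tell `model.eval` which dimensions of `img_3D` hold the planes and which hold the channels. The display code below indexes the stack as `img_3D[iplane, [1,0]]`, which implies a (Z, C, Y, X) layout. The sketch below illustrates only that bookkeeping: it moves the named axes into a fixed (Z, Y, X, C) order on a dummy array. The shape, including the channel count, is made up, the helper `to_zyxc` is hypothetical, and this is not claimed to be how Cellpose reorders its input internally.

```python
import numpy as np


def to_zyxc(vol: np.ndarray, z_axis: int, channel_axis: int) -> np.ndarray:
    """Reorder a 4D array so planes come first and channels last: (Z, Y, X, C)."""
    return np.moveaxis(vol, [z_axis, channel_axis], [0, -1])


# Dummy volume laid out as (Z, C, Y, X); all sizes are illustrative.
dummy = np.zeros((75, 3, 64, 48), dtype=np.uint8)
print(to_zyxc(dummy, z_axis=0, channel_axis=1).shape)
```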
Second way: compute masks on each 2D slice and stitch them into 3D based on mask overlap.
Note that stitching (with stitch_threshold > 0) can also be used to track cells over time.
# 2. computes masks in 2D slices and stitches masks in 3D based on mask overlap
print('running cellpose 2D + stitching masks')
masks_stitched, flows_stitched, _ = model.eval(img_3D, z_axis=0, channel_axis=1,
                                               batch_size=32,
                                               do_3D=False, stitch_threshold=0.5)
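With `stitch_threshold=0.5`, masks found independently in neighboring planes are merged into one 3D object when they overlap enough. Applying the same logic along a time axis is what makes stitching usable for tracking. The sketch below illustrates the idea under two assumptions: overlap is measured as intersection-over-union (IoU), and each mask takes the label of its best-overlapping mask in the previous plane when that IoU reaches the threshold. The function `stitch_planes` is hypothetical and simpler than Cellpose's own stitching. For example, it does not stop two masks from claiming the same label.

```python
import numpy as np


def stitch_planes(planes: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Relabel a (Z, Ly, Lx) stack of independent 2D label images so that a mask
    overlapping a mask in the previous plane with IoU >= threshold shares its label."""
    out = np.zeros_like(planes)
    out[0] = planes[0]
    next_label = int(planes[0].max()) + 1
    for z in range(1, planes.shape[0]):
        prev, cur = out[z - 1], planes[z]
        for lab in np.unique(cur[cur > 0]):
            m = cur == lab
            best_iou, best_lab = 0.0, 0
            # only previous-plane masks that touch this mask can have IoU > 0
            for plab in np.unique(prev[m & (prev > 0)]):
                pm = prev == plab
                iou = np.logical_and(m, pm).sum() / np.logical_or(m, pm).sum()
                if iou > best_iou:
                    best_iou, best_lab = iou, int(plab)
            if best_lab > 0 and best_iou >= threshold:
                out[z][m] = best_lab  # continue the existing 3D object
            else:
                out[z][m] = next_label  # start a new 3D object
                next_label += 1
    return out


# Plane 0 has object A (label 1); in plane 1, A drifts by one pixel (IoU 0.6)
# and a new object B appears. Plane-1 labels are arbitrary (7 and 3).
demo_planes = np.zeros((2, 10, 10), dtype=np.int32)
demo_planes[0, 1:5, 1:5] = 1
demo_planes[1, 1:5, 2:6] = 7
demo_planes[1, 6:9, 6:9] = 3
demo_stitched = stitch_planes(demo_planes, threshold=0.5)
print(np.unique(demo_stitched[1]))
```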
Results from computing masks from 3D flows
# DISPLAY RESULTS 3D flows => masks
plt.figure(figsize=(15,3))
for i,iplane in enumerate(np.arange(0,75,10,int)):
    img0 = plot.image_to_rgb(img_3D[iplane, [1,0]].copy(), channels=[2,3])
    plt.subplot(1,8,i+1)
    outlines = utils.masks_to_outlines(masks[iplane])
    outX, outY = np.nonzero(outlines)
    imgout = img0.copy()
    imgout[outX, outY] = np.array([255,75,75])
    plt.imshow(imgout)
    plt.title('iplane = %d'%iplane)
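Each display loop turns a label image into a boolean outline image with `utils.masks_to_outlines` and paints those pixels red. The sketch below is a simplified NumPy version of that step. It assumes an outline pixel is a labeled pixel with at least one 4-connected neighbor carrying a different label, including background; Cellpose's own function may define boundaries differently. The name `simple_outlines` is hypothetical.

```python
import numpy as np


def simple_outlines(masks: np.ndarray) -> np.ndarray:
    """Boolean image marking labeled pixels that touch a different label
    (including background) among their 4-connected neighbors."""
    padded = np.pad(masks, 1, mode="constant")  # pad with background (0)
    center = padded[1:-1, 1:-1]
    edge = np.zeros(masks.shape, dtype=bool)
    for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
        neighbor = padded[1 + dy : padded.shape[0] - 1 + dy, 1 + dx : padded.shape[1] - 1 + dx]
        edge |= center != neighbor
    return edge & (masks > 0)


# A 5x5 square cell in a 7x7 image: its outline is the 16-pixel ring.
demo_masks = np.zeros((7, 7), dtype=np.int32)
demo_masks[1:6, 1:6] = 1
demo_outline = simple_outlines(demo_masks)

# Overlay the outline in red on a gray RGB image, as in the display loops.
demo_rgb = np.full((7, 7, 3), 128, dtype=np.uint8)
oy, ox = np.nonzero(demo_outline)
demo_rgb[oy, ox] = np.array([255, 75, 75])
print(demo_outline.sum())
```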
Results from stitching
# DISPLAY RESULTS stitching
plt.figure(figsize=(15,3))
for i,iplane in enumerate(np.arange(0,75,10,int)):
    img0 = plot.image_to_rgb(img_3D[iplane, [1,0]].copy(), channels=[2,3])
    plt.subplot(1,8,i+1)
    outlines = utils.masks_to_outlines(masks_stitched[iplane])
    outX, outY = np.nonzero(outlines)
    imgout = img0.copy()
    imgout[outX, outY] = np.array([255,75,75])
    plt.imshow(imgout)
    plt.title('iplane = %d'%iplane)