#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CTSegNet is more than a 2D CNN model - it's a 3D Segmenter that uses 2D CNNs. The seg_utils.py module defines the Segmenter class that wraps over a keras U-net-like model (defined by models.py), integrating 3D slicing and 2D patching functions to enable the 3D-2D-3D conversions in the segmentation workflow. To slice a 3D volume, manipulations such as 45 deg rotations, orthogonal slicing, patch extraction and stitching are performed.
"""
import sys
import os
# line 13 empty for good luck
import numpy as np
import pandas as pd
import re
import ast
import h5py
import cv2
import time
import tensorflow as tf
from tensorflow.keras.models import load_model
from ct_segnet.data_utils import patch_maker as PM
from ct_segnet.data_utils.data_io import Parallelize
from ct_segnet.model_utils.losses import custom_objects_dict
# Module-wide switch for the progress messages emitted through message().
VERBOSE = False


def message(_str):
    """
    Print ``_str`` to stdout, but only while the module-level VERBOSE flag is truthy.

    Parameters
    ----------
    _str : str
        text to print

    """
    if not VERBOSE:
        return
    print(_str)
class Segmenter():
    """
    The Segmenter class wraps over a keras model, integrating 3D slicing and 2D patching functions to enable the 3D-2D-3D conversions in the segmentation workflow.

    Parameters
    ----------
    model_filename : str
        path to keras model file (e.g. "model_1.h5"). Used only when ``model`` is not passed; the model is then read with ``load_model`` and the model name is taken from the file name (text before the first '.').

    model : tf.keras.model
        keras model with input shape = out shape = (ny, nx, 1)

    model_name : str
        (optional) just a name for the model

    weight_file_name : str
        (optional) weights file that is loaded into ``model`` with ``model.load_weights``. Requires ``model`` to be passed.

    GPU_mem_limit : float
        max limit of GPU memory to use. The value is multiplied by 1000 and passed as ``memory_limit`` of a virtual device on the first GPU (NOTE(review): TensorFlow documents ``memory_limit`` in MB, so this is presumably GB - confirm).

    Attributes
    ----------
    model : tf.keras.model
        the wrapped keras model

    model_name : str
        name of the wrapped model

    """

    def __init__(self, model_filename = None, model = None, model_name = "unknown", weight_file_name = None, GPU_mem_limit = 16.0):

        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            try:
                tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=GPU_mem_limit*1000.0)])
            except RuntimeError as e:
                # Best effort: a failure to apply the memory limit is reported, not fatal.
                print(e)

        # if you get serialization errors using load_model
        if weight_file_name is not None:
            if model is None:
                raise ValueError("weight_file_name requires a model to load the weights into.")
            self.model = model
            # Set the name on this path too, so that the attribute always exists.
            self.model_name = model_name
            self.model.load_weights(weight_file_name) # note that weights can be loaded from a full save, not only from save_weights file
            return

        if model is not None:
            self.model = model
            self.model_name = model_name
        elif model_filename is not None:
            self.model_name = os.path.split(model_filename)[-1].split('.')[0]
            self.model = load_model(model_filename, custom_objects = custom_objects_dict)
        else:
            raise ValueError("Either model or model_filename must be provided.")

    @staticmethod
    def _resolve_patching(max_patches, overlap):
        """
        Normalize the patching arguments shared by seg_image and seg_chunk.

        Parameters
        ----------
        max_patches : tuple or int
            (my, mx) number of patches along Y, X; an int is used for both axes

        overlap : tuple or int
            overlapping pixels along Y, X; an int is used for both axes and None means no overlap

        Returns
        -------
        tuple
            ((my, mx), (oy, ox)), where the overlap along an axis is forced to 0 if that axis holds a single patch

        Raises
        ------
        TypeError
            if max_patches is None

        """
        if max_patches is None:
            raise TypeError("max_patches must be an int or a tuple (my, mx), got None")
        if not isinstance(max_patches, (tuple, list)):
            max_patches = (max_patches, max_patches)
        if overlap is None:
            overlap = 0
        if not isinstance(overlap, (tuple, list)):
            overlap = (overlap, overlap)

        # A single patch along an axis has no neighbour to overlap with.
        overlap = (0 if max_patches[0] == 1 else overlap[0],
                   0 if max_patches[1] == 1 else overlap[1])
        return tuple(max_patches), overlap

    def seg_image(self, s, max_patches = None, overlap = None):
        """
        Test the segmenter on arbitrary sized 2D image. This method extracts patches of shape same as the input shape of 2D CNN, segments them and stitches them back to form the original image.

        Parameters
        ----------
        s : numpy.array
            greyscale image slice of shape (ny, nx)

        max_patches : tuple or int
            (my, mx) are # of patches along Y, X in image

        overlap : tuple or int
            number of overlapping pixels between patches

        Returns
        -------
        numpy.array
            segmented image (np.uint8) of same shape as input image

        """
        # Handle patching parameter inputs
        patch_size = tuple(self.model.output_shape[1:-1])
        max_patches, overlap = self._resolve_patching(max_patches, overlap)

        # Resize images (cv2.resize takes the target size as (width, height))
        orig_shape = s.shape
        s = cv2.resize(s, (max_patches[1]*patch_size[1] - overlap[1],
                           max_patches[0]*patch_size[0] - overlap[0]))

        # Make patches
        downres_shape = s.shape
        steps = PM.get_stepsize(downres_shape, patch_size)
        s = PM.get_patches(s, patch_size = patch_size, steps = steps)

        # The dataset now has shape: (ny, nx, py, px). ny, nx are # of patches, and py, px is patch_shape.
        # Reshape this dataset into (n, py, px) where n = ny*nx.
        dataset_shape = s.shape
        s = s.reshape((-1,) + patch_size)

        # Predict using the model (a trailing channel axis is added for it and dropped afterwards).
        s = self.model.predict(s[...,np.newaxis])
        s = s[...,0]

        # Now, reshape the data back...
        s = s.reshape(dataset_shape)

        # Reconstruct from patches...
        s = PM.recon_patches(s, img_shape = downres_shape, steps = steps)

        # Finally, resize the images to the original shape of slices... This will result in some loss of resolution...
        s = cv2.resize(s, (orig_shape[1], orig_shape[0]))

        return np.asarray(np.round(s)).astype(np.uint8)

    def seg_chunk(self, p, max_patches = None, overlap = None,
                  nprocs = None, arr_split = 1, arr_split_infer = 1):
        """
        Segment a volume of shape (nslices, ny, nx). The 2D keras model passes along nslices, segmenting images (ny, nx) with a patch size defined by input to the model

        Parameters
        ----------
        p : numpy.array
            volume of shape (nslices, ny, nx)

        max_patches : tuple or int
            (my, mx) are # of patches along Y, X in image (ny, nx)

        overlap : tuple or int
            number of overlapping pixels between patches

        nprocs : int
            number of CPU processors for multiprocessing Pool (passed to Parallelize as ``procs``)

        arr_split : int
            breakdown chunk into arr_split number of smaller chunks for stitching the patches back

        arr_split_infer : int
            breakdown the patches into arr_split_infer number of batches for model prediction

        Returns
        -------
        numpy.array
            segmented volume with slices resized back to the original (ny, nx). Predictions are rounded and cast to np.uint8 before the patches are stitched.

        """
        # Handle patching parameter inputs
        patch_size = tuple(self.model.output_shape[1:-1])
        max_patches, overlap = self._resolve_patching(max_patches, overlap)

        # Resize images (cv2.resize takes the target size as (width, height))
        orig_shape = p[0].shape
        resized_wh = (max_patches[1]*patch_size[1] - overlap[1],
                      max_patches[0]*patch_size[0] - overlap[0])
        p = np.asarray([cv2.resize(img, resized_wh) for img in p])

        # Make patches
        message("Making patches...")
        message("\tCurrent d shape:" + str(np.shape(p)))
        downres_shape = p[0].shape
        steps = PM.get_stepsize(downres_shape, patch_size)
        p = Parallelize(p, PM.get_patches, procs = nprocs,
                        patch_size = patch_size, steps = steps)
        p = np.asarray(p)

        # The dataset now has shape: (nslices, ny, nx, py, px),
        # where ny, nx are # of patches, and py, px is patch_shape.
        # Reshape this dataset into (n, py, px) where n = nslices*ny*nx.
        dataset_shape = p.shape
        p = p.reshape((-1,) + patch_size)

        # Predict using the model.
        message("Running predictions using model...")
        message("\tCurrent d shape:" + str(np.shape(p)))
        # np.array_split yields empty sub-arrays when asked for more pieces than
        # there are patches; those are skipped instead of being sent to the model.
        batches = [batch for batch in np.array_split(p, arr_split_infer) if len(batch)]
        # Round and typecast per batch, so that the float predictions of all
        # batches never have to be held in memory together.
        p = [np.round(self.model.predict(batch[...,np.newaxis])[...,0]).astype(np.uint8)
             for batch in batches]
        p = np.concatenate(p, axis = 0)

        # Now, reshape the data back...
        p = p.reshape(dataset_shape)

        # Reconstruct from patches...
        message("Reconstructing from patches...")
        message("\tCurrent d shape:" + str(np.shape(p)))
        chunks = [chunk for chunk in np.array_split(p, arr_split) if len(chunk)]
        p = [np.asarray(Parallelize(chunk, PM.recon_patches,
                                    img_shape = downres_shape,
                                    steps = steps, procs = nprocs))
             for chunk in chunks]
        p = np.concatenate(p, axis = 0)

        # Finally, resize the images to the original shape of slices... This will result in some loss of resolution...
        message("Resizing images to original slice size...")
        message("\tCurrent d shape:" + str(np.shape(p)))
        p = np.asarray([cv2.resize(img, (orig_shape[1], orig_shape[0])) for img in p])

        return p
def get_repadding(crops, d_shape):
    """
    Compute the padding that restores a 3D array to its original shape after it was cropped.

    The crop bounds are resolved the way python slicing resolves them (None, negative and out-of-range values included), so the returned padding always adds up with the cropped length to the original length of each axis.

    Returns
    -------
    tuple
        padding values to restore 3D np array after it was cropped; one (before, after) tuple per axis.

    Parameters
    ----------
    crops : list
        3 tuples in a list [(nz1,nz2), (ny1,ny2), (nx1,nx2)]

    d_shape : tuple
        original shape of 3D array

    """
    pads = []
    for idx, crop in enumerate(crops):
        axis_len = d_shape[idx]
        # slice.indices() clamps the bounds exactly as the crop itself does.
        start, stop, _ = slice(crop[0], crop[1]).indices(axis_len)
        # An inverted crop selects nothing: everything from start on is padding.
        stop = max(start, stop)
        pads.append((start, axis_len - stop))
    return tuple(pads)
def _rotate(imgs, angle):
    """
    Rotate every image of a stack by the same angle, using cv2's affine transform.

    The rotation is taken about the point (cols/2, rows/2) with unit scale, and each output image keeps the (rows, cols) size of the input images.

    Parameters
    ----------
    imgs : np.array
        volume or series of images (n, ny, nx)

    angle : float
        value to rotate image about center, along (ny,nx)

    Returns
    -------
    np.array
        the rotated images, stacked along the first axis

    """
    n_rows, n_cols = imgs[0].shape
    center = (n_cols/2, n_rows/2)
    # One rotation matrix serves the whole stack.
    rot_mat = cv2.getRotationMatrix2D(center, angle, 1)
    out_size = (n_cols, n_rows)

    rotated = []
    for idx in range(len(imgs)):
        rotated.append(cv2.warpAffine(imgs[idx], rot_mat, out_size))
    return np.asarray(rotated)
def process_data(p, segmenter, preprocess_func = None, max_patches = None,
                 overlap = None, nprocs = None, rot_angle = 0.0, slice_axis = 0,
                 crops = None, arr_split = 1, arr_split_infer = 1):
    """
    Segment a volume of shape (nz, ny, nx). The 2D keras model passes
    along either axis (0,1,2), segmenting images with a patch size defined by input
    to the model in the segmenter class.

    The volume is rotated, cropped and re-oriented as requested, segmented with
    ``segmenter.seg_chunk`` and then brought back by undoing these steps in reverse
    order (re-orient, zero-pad, rotate back).

    Parameters
    ----------
    p : numpy.array
        3D volume of shape (nz, ny, nx)

    segmenter : Segmenter
        object whose ``seg_chunk`` method segments the re-oriented volume

    preprocess_func : func
        pass a preprocessing function that applies a 2D filter on an image. It is called once with the whole re-oriented volume.

    max_patches : tuple
        (my, mx) are # of patches along Y, X in image (ny, nx)

    overlap : tuple or int
        number of overlapping pixels between patches

    nprocs : int
        number of CPU processors for multiprocessing Pool (4 if None)

    rot_angle : float
        (degrees) rotate volume around Z axis before slicing along any given axis. Note this is redundant if slice_axis = 0. Any non-zero angle, positive or negative, is applied.

    slice_axis : int
        (0,1,2); axis along which to draw slices

    crops : list
        list of three tuples; each tuple (start, stop) will define a python slice for the respective axis

    arr_split : int
        breakdown chunk into arr_split number of smaller chunks

    arr_split_infer : int
        passed on to ``segmenter.seg_chunk`` as ``arr_split_infer``

    Returns
    -------
    numpy.array
        segmented volume cast to np.uint8; cropped-out regions are filled with 0

    Raises
    ------
    ValueError
        if p is not 3-dimensional

    """
    if nprocs is None:
        nprocs = 4

    if p.ndim != 3:
        raise ValueError("Invalid dimensions for 3D data.")

    message("Orienting, rotating and padding as requested...")
    # Rotate the volume along axis 0, if requested. Negative angles are valid
    # rotations as well, so the test is for a non-zero angle.
    rotate = rot_angle != 0.0
    if rotate:
        p = _rotate(p, rot_angle)

    if crops is not None:
        # The pads are computed before cropping, from the un-cropped shape.
        pads = get_repadding(crops, p.shape)
        p = p[slice(*crops[0]), slice(*crops[1]), slice(*crops[2])]

    # Orient the volume such that the first axis is the direction in which to slice through...
    p = np.moveaxis(p, slice_axis, 0)
    message("\tDone")

    # Preprocess function
    if preprocess_func is not None:
        p = preprocess_func(p)

    # Run the segmenter algorithm
    p = segmenter.seg_chunk(p, max_patches = max_patches, overlap = overlap, nprocs = nprocs, arr_split = arr_split, arr_split_infer = arr_split_infer)

    message("Re-orienting, rotating and padding back original size...")
    # Re-orient the volume such that the first axis is the vertical axis...
    p = np.moveaxis(p, 0, slice_axis)

    # Pad the volume to bring it back to original dimensions
    if crops is not None:
        p = np.pad(p, pads, 'constant', constant_values = 0)

    # Rotate the volume along axis 0, back to its original state
    if rotate:
        p = _rotate(p, -rot_angle)
    message("\tDone")

    return p.astype(np.uint8)
if __name__ == "__main__":
    # Welcome text framed by two rules of 50 '#' characters; shown only when VERBOSE is set.
    rule = "\n" + "#"*50 + "\n"
    for line in (rule, "Welcome to CTSegNet: AI-based 3D Segmentation\n", rule):
        message(line)