from fastai.data.external import *
Core vision
Helpers
im = Image.open(TEST_IMAGE).resize((30,20))
Image.n_px
Image.n_px (x:PIL.Image.Image)
test_eq(im.n_px, 30*20)
Image.shape
Image.shape (x:PIL.Image.Image)
test_eq(im.shape, (20,30))
Image.aspect
Image.aspect (x:PIL.Image.Image)
test_eq(im.aspect, 30/20)
Image.reshape
Image.reshape (x:PIL.Image.Image, h, w, resample=0)
resize x to (w,h)
test_eq(im.reshape(12,10).shape, (12,10))
Image.to_bytes_format
Image.to_bytes_format (im:PIL.Image.Image, format='png')
Convert to bytes, default to PNG format
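For example (a quick sketch, not one of the library's tests, reusing the im defined above):
b = im.to_bytes_format()   # PNG-encoded bytes by default
test_eq(type(b), bytes)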
Image.to_thumb
Image.to_thumb (h, w=None)
Same as thumbnail, but uses a copy
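A minimal usage sketch (assumed, reusing the 30x20 im from above): to_thumb fits the image within the requested size and leaves the original untouched.
thumb = im.to_thumb(10)
test_eq(im.size, (30,20))      # the original is unchanged, a copy is returned
assert max(thumb.size) <= 10   # the thumbnail fits within the requested size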
Image.resize_max
Image.resize_max (x:PIL.Image.Image, resample=0, max_px=None, max_h=None, max_w=None)
resize x to max_px, or max_h, or max_w
test_eq(im.resize_max(max_px=20*30).shape, (20,30))
test_eq(im.resize_max(max_px=300).n_px, 294)
test_eq(im.resize_max(max_px=500, max_h=10, max_w=20).shape, (10,15))
test_eq(im.resize_max(max_h=14, max_w=15).shape, (10,15))
test_eq(im.resize_max(max_px=300, max_h=10, max_w=25).shape, (10,15))
Basic types
This section regroups the basic types used in vision with the transforms that create objects of those types.
to_image
to_image (x)
Convert a tensor or array to a PIL int8 Image
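A hedged example, assuming a float32 tensor in c*h*w order with values in [0,1]:
t = torch.rand(3, 20, 30)    # assumed channels-first float image
img = to_image(t)
test_eq(img.size, (30, 20))  # PIL size is (width, height)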
load_image
load_image (fn, mode=None)
Open and load a PIL.Image and convert to mode
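A minimal sketch of the expected usage, reusing TEST_IMAGE from above:
img = load_image(TEST_IMAGE, mode='RGB')
test_eq(img.mode, 'RGB')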
image2tensor
image2tensor (img)
Transform image to byte tensor in c*h*w dim order.
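For instance (a hedged check on the 30x20 RGB im opened earlier):
t = image2tensor(im)
test_eq(t.shape, (3, 20, 30))   # channels first: c*h*w
test_eq(t.dtype, torch.uint8)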
PILBase
PILBase ()
Base class for a Pillow Image that can show itself and convert to a Tensor
PILBase.create
PILBase.create (fn:pathlib.Path|str|torch.Tensor|numpy.ndarray|bytes|PIL.Image.Image, **kwargs)
Return an Image from fn
Images passed to PILBase or inherited classes’ create as a PyTorch Tensor, NumPy ndarray, or Pillow Image must already be in the correct Pillow image format: for example, uint8 dtype, and RGB or BW mode for PILImage or PILImageBW, respectively.
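For instance, a hypothetical uint8 RGB array in the format described above converts directly:
arr = np.zeros((20, 30, 3), dtype=np.uint8)   # made-up uint8 RGB array
img = PILImage.create(arr)
test_eq(type(img), PILImage)
test_eq(img.mode, 'RGB')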
PILBase.show
PILBase.show (ctx=None, **kwargs)
Show image using merge(self._show_args, kwargs)
PILImage
PILImage ()
An RGB Pillow Image that can show itself and converts to TensorImage
PILImageBW
PILImageBW ()
A BW Pillow Image that can show itself and converts to TensorImageBW
im = PILImage.create(TEST_IMAGE)
test_eq(type(im), PILImage)
test_eq(im.mode, 'RGB')
test_eq(str(im), 'PILImage mode=RGB size=1200x803')
im2 = PILImage.create(im)
test_eq(type(im2), PILImage)
test_eq(im2.mode, 'RGB')
test_eq(str(im2), 'PILImage mode=RGB size=1200x803')
im.resize((64,64))
ax = im.show(figsize=(1,1))
test_fig_exists(ax)
timg = TensorImage(image2tensor(im))
tpil = PILImage.create(timg)
tpil.resize((64,64))
PILMask
PILMask ()
A Pillow Image Mask that can show itself and converts to TensorMask
im = PILMask.create(TEST_IMAGE)
test_eq(type(im), PILMask)
test_eq(im.mode, 'L')
test_eq(str(im), 'PILMask mode=L size=1200x803')
Images
mnist = untar_data(URLs.MNIST_TINY)
fns = get_image_files(mnist)
mnist_fn = TEST_IMAGE_BW

timg = Transform(PILImageBW.create)
mnist_img = timg(mnist_fn)
test_eq(mnist_img.size, (28,28))
assert isinstance(mnist_img, PILImageBW)
mnist_img
Segmentation masks
AddMaskCodes
AddMaskCodes (codes=None)
Add the code metadata to a TensorMask
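As a small hedged sketch (the codes below are made up), the transform stores the vocabulary so decoded masks know their classes:
codes = ['background', 'building', 'road']   # hypothetical codes
add_codes = AddMaskCodes(codes=codes)
test_eq(add_codes.c, 3)                      # number of classes
test_eq(add_codes.vocab, codes)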
camvid = untar_data(URLs.CAMVID_TINY)
fns = get_image_files(camvid/'images')
cam_fn = fns[0]
mask_fn = camvid/'labels'/f'{cam_fn.stem}_P{cam_fn.suffix}'
cam_img = PILImage.create(cam_fn)
test_eq(cam_img.size, (128,96))
tmask = Transform(PILMask.create)
mask = tmask(mask_fn)
test_eq(type(mask), PILMask)
test_eq(mask.size, (128,96))
_,axs = plt.subplots(1,3, figsize=(12,3))
cam_img.show(ctx=axs[0], title='image')
mask.show(alpha=1, ctx=axs[1], vmin=1, vmax=30, title='mask')
cam_img.show(ctx=axs[2], title='superimposed')
mask.show(ctx=axs[2], vmin=1, vmax=30);
Points
TensorPoint
TensorPoint (x, **kwargs)
Basic type for points in an image
Points are expected to come as an array/tensor of shape (n,2) or as a list of lists with two elements. Unless you change the defaults in PointScaler (see later on), coordinates should go from 0 to width/height, with the first one being the column index (so from 0 to width) and the second one being the row index (so from 0 to height).

This is different from the usual indexing convention for arrays in numpy or in PyTorch, but it’s the way points are expected by matplotlib or the internal functions in PyTorch like F.grid_sample.
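For instance, under this convention a point at column 10 and row 5 is written [10,5]; a quick illustration (not part of the original tests):
pt = TensorPoint.create([[10, 5]])   # (x=column, y=row)
test_eq(pt.shape, [1,2])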
pnt_img = TensorImage(mnist_img.resize((28,35)))
pnts = np.array([[0,0], [0,35], [28,0], [28,35], [9, 17]])
tfm = Transform(TensorPoint.create)
tpnts = tfm(pnts)
test_eq(tpnts.shape, [5,2])
test_eq(tpnts.dtype, torch.float32)

ctx = pnt_img.show(figsize=(1,1), cmap='Greys')
tpnts.show(ctx=ctx);
Bounding boxes
get_annotations
get_annotations (fname, prefix=None)
Open a COCO style json in fname and return the lists of filenames (with maybe prefix) and labelled bboxes.
Test get_annotations on the coco_tiny dataset against both image filenames and bounding box labels.
coco = untar_data(URLs.COCO_TINY)
test_images, test_lbl_bbox = get_annotations(coco/'train.json')
annotations = json.load(open(coco/'train.json'))
categories, images, annots = map(lambda x:L(x),annotations.values())

test_eq(test_images, images.attrgot('file_name'))

def bbox_lbls(file_name):
    img = images.filter(lambda img:img['file_name']==file_name)[0]
    bbs = annots.filter(lambda a:a['image_id'] == img['id'])
    i2o = {k['id']:k['name'] for k in categories}
    lbls = [i2o[cat] for cat in bbs.attrgot('category_id')]
    bboxes = [[bb[0],bb[1], bb[0]+bb[2], bb[1]+bb[3]] for bb in bbs.attrgot('bbox')]
    return [bboxes, lbls]

for idx in random.sample(range(len(images)),5):
    test_eq(test_lbl_bbox[idx], bbox_lbls(test_images[idx]))
TensorBBox
TensorBBox (x, **kwargs)
Basic type for a tensor of bounding boxes in an image
Bounding boxes are expected to come as a tuple with an array/tensor of shape (n,4) or as a list of lists with four elements, plus a list of corresponding labels. Unless you change the defaults in PointScaler (see later on), coordinates for each bounding box should go from 0 to width/height, with the following convention: x1, y1, x2, y2 where (x1,y1) is your top-left corner and (x2,y2) is your bottom-right corner.

We use the same convention as for points, with x going from 0 to width and y going from 0 to height.
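As a hedged illustration of that convention (values made up), a box with top-left corner (10,20) and bottom-right corner (40,60):
bb = TensorBBox.create([[10., 20., 40., 60.]])   # x1, y1, x2, y2
test_eq(bb.shape, [1,4])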
LabeledBBox
LabeledBBox (items=None, *rest, use_list=False, match=None)
Basic type for a list of bounding boxes in an image
coco = untar_data(URLs.COCO_TINY)
images, lbl_bbox = get_annotations(coco/'train.json')
idx=2
coco_fn,bbox = coco/'train'/images[idx],lbl_bbox[idx]
coco_img = timg(coco_fn)

tbbox = LabeledBBox(TensorBBox(bbox[0]), bbox[1])
ctx = coco_img.show(figsize=(3,3), cmap='Greys')
tbbox.show(ctx=ctx);
Basic Transforms
Unless specifically mentioned, all the following transforms can be used as single-item transforms (in one of the lists in the tfms you pass to a TfmdDS or a Datasource) or tuple transforms (in the tuple_tfms you pass to a TfmdDS or a Datasource). The safest way that will work across applications is to always use them as tuple_tfms. For instance, if you have points or bounding boxes as targets and use Resize as a single-item transform, when you get to PointScaler (which is a tuple transform) you won’t have the correct size of the image to properly scale your points.
encodes
encodes (o:__main__.PILMask)
Any data augmentation transform that runs on PIL Images must be run before this transform.
tfm = ToTensor()
print(tfm)
print(type(mnist_img))
print(type(tfm(mnist_img)))
ToTensor:
encodes: (PILMask,object) -> encodes
(PILBase,object) -> encodes
(PILMask,object) -> encodes
(PILBase,object) -> encodes
decodes:
<class '__main__.PILImageBW'>
<class 'fastai.torch_core.TensorImageBW'>
tfm = ToTensor()
test_eq(tfm(mnist_img).shape, (1,28,28))
test_eq(type(tfm(mnist_img)), TensorImageBW)
test_eq(tfm(mask).shape, (96,128))
test_eq(type(tfm(mask)), TensorMask)
Let’s confirm we can pipeline this with PILImage.create.
pipe_img = Pipeline([PILImageBW.create, ToTensor()])
img = pipe_img(mnist_fn)
test_eq(type(img), TensorImageBW)
pipe_img.show(img, figsize=(1,1));
def _cam_lbl(x): return mask_fn
cam_tds = Datasets([cam_fn], [[PILImage.create, ToTensor()], [_cam_lbl, PILMask.create, ToTensor()]])
show_at(cam_tds, 0);
To work with data augmentation, and in particular the grid_sample method, points need to be represented with coordinates going from -1 to 1 (-1 being top or left, 1 bottom or right), which will be done unless you pass do_scale=False. We also need to make sure they follow our convention of points being x,y coordinates, so pass y_first=True if your data comes in a y,x format and the transform will flip the coordinates.
This transform needs to run on the tuple level, before any transform that changes the image size.
PointScaler
PointScaler (do_scale=True, y_first=False)
Scale a tensor representing points
This transform automatically grabs the sizes of the images it sees before a TensorPoint object and embeds them in the points. For this to work, those images need to come before any points in the order of your final tuple. If you don’t have such images, you need to embed the size of the corresponding image when creating a TensorPoint by passing it with sz=....
def _pnt_lbl(x): return TensorPoint.create(pnts)
def _pnt_open(fn): return PILImage(PILImage.create(fn).resize((28,35)))
pnt_tds = Datasets([mnist_fn], [_pnt_open, [_pnt_lbl]])
pnt_tdl = TfmdDL(pnt_tds, bs=1, after_item=[PointScaler(), ToTensor()])
test_eq(pnt_tdl.after_item.c, 10)
x,y = pnt_tdl.one_batch()
#Scaling and flipping properly done
#NB: we added a point earlier at (9,17); formula below scales to (-1,1) coords
test_close(y[0], tensor([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.], [9/14-1, 17/17.5-1]]))
a,b = pnt_tdl.decode_batch((x,y))[0]
test_eq(b, tensor(pnts).float())
#Check types
test_eq(type(x), TensorImage)
test_eq(type(y), TensorPoint)
test_eq(type(a), TensorImage)
test_eq(type(b), TensorPoint)
test_eq(b.img_size, (28,35)) #Automatically picked the size of the input
pnt_tdl.show_batch(figsize=(2,2), cmap='Greys');
BBoxLabeler
BBoxLabeler (enc=None, dec=None, split_idx=None, order=None)
Delegates (__call__, decode, setup) to (encodes, decodes, setups) if split_idx matches
decodes
decodes (x:__main__.LabeledBBox)
decodes
decodes (x:__main__.TensorBBox)
encodes
encodes (x:__main__.TensorBBox)
def _coco_bb(x): return TensorBBox.create(bbox[0])
def _coco_lbl(x): return bbox[1]
coco_tds = Datasets([coco_fn], [PILImage.create, [_coco_bb], [_coco_lbl, MultiCategorize(add_na=True)]], n_inp=1)
coco_tdl = TfmdDL(coco_tds, bs=1, after_item=[BBoxLabeler(), PointScaler(), ToTensor()])

Categorize(add_na=True)
Categorize -- {'vocab': None, 'sort': True, 'add_na': True}:
encodes: (object,object) -> encodes
decodes: (object,object) -> decodes
coco_tds.tfms
(#3) [Pipeline: PILBase.create,Pipeline: _coco_bb,Pipeline: _coco_lbl -> MultiCategorize -- {'vocab': None, 'sort': True, 'add_na': True}]
x,y,z
(PILImage mode=RGB size=128x128,
TensorBBox([[-0.9011, -0.4606, 0.1416, 0.6764],
[ 0.2000, -0.2405, 1.0000, 0.9102],
[ 0.4909, -0.9325, 0.9284, -0.5011]]),
TensorMultiCategory([1, 1, 1]))
x,y,z = coco_tdl.one_batch()
test_close(y[0], -1+tensor(bbox[0])/64)
test_eq(z[0], tensor([1,1,1]))
a,b,c = coco_tdl.decode_batch((x,y,z))[0]
test_close(b, tensor(bbox[0]).float())
test_eq(c.bbox, b)
test_eq(c.lbl, bbox[1])

#Check types
test_eq(type(x), TensorImage)
test_eq(type(y), TensorBBox)
test_eq(type(z), TensorMultiCategory)
test_eq(type(a), TensorImage)
test_eq(type(b), TensorBBox)
test_eq(type(c), LabeledBBox)
test_eq(y.img_size, (128,128))
coco_tdl.show_batch();