{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "GeneralizationToImageNetV2",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/rwightman/pytorch-image-models/blob/master/notebooks/GeneralizationToImageNetV2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NxIFjw_uOaHa",
"colab_type": "text"
},
"source": [
"# How Do ImageNet-1k Models Generalize to ImageNet-V2?\n",
"\n",
"I was recently [benchmarking the runtime performance](https://colab.research.google.com/github/rwightman/pytorch-image-models/blob/master/notebooks/EffResNetComparison.ipynb) of some models in Colab on the [ImageNet-V2 dataset](https://github.com/modestyachts/ImageNetV2) and noticed something interesting: the [Facebook WSL Instagram pretrained ResNeXt](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) model had a smaller gap between its ImageNet-1k and ImageNet-V2 accuracy than any other model I'd seen to date. I decided to dig in a bit more in this notebook and compare the rest of the WSL models, along with a reasonable sampling of other models, with respect to their generalization gap between ImageNet-1k and ImageNet-V2."
]
},
{
"cell_type": "code",
"metadata": {
"id": "gncNNhwIOMma",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 136
},
"outputId": "631f1f6a-4959-4bff-9e11-51d61b33fa5b"
},
"source": [
"!pip install timm"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: timm in /usr/local/lib/python3.6/dist-packages (0.1.8)\n",
"Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (from timm) (0.3.0)\n",
"Requirement already satisfied: torch>=1.0 in /usr/local/lib/python3.6/dist-packages (from timm) (1.1.0)\n",
"Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision->timm) (4.3.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision->timm) (1.16.4)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision->timm) (1.12.0)\n",
"Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from pillow>=4.1.1->torchvision->timm) (0.46)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xBDgzTWROeED",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 306
},
"outputId": "76c9e437-2dda-41aa-c61d-e42553ccf443"
},
"source": [
"# For our convenience, take a peek at what we're working with\n",
"!nvidia-smi"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Sat Jul 6 22:42:48 2019 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 418.67 Driver Version: 410.79 CUDA Version: 10.0 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 48C P8 16W / 70W | 0MiB / 15079MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: GPU Memory |\n",
"| GPU PID Type Process name Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "jXEyxmp_OoLF",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"outputId": "dbfc984a-9039-42ca-d96c-3d42ad23337b"
},
"source": [
"# Import the core modules, check which GPU we end up with and scale batch size accordingly\n",
"import torch\n",
"torch.backends.cudnn.benchmark = True\n",
"\n",
"import timm\n",
"from timm.data import *\n",
"from timm.utils import *\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import pynvml\n",
"from collections import OrderedDict\n",
"import logging\n",
"import time\n",
"import os\n",
"import shutil\n",
"\n",
"def log_gpu_memory():\n",
"    handle = pynvml.nvmlDeviceGetHandleByIndex(0)\n",
"    info = pynvml.nvmlDeviceGetMemoryInfo(handle)\n",
"    info.free = round(info.free / 1024**2)\n",
"    info.used = round(info.used / 1024**2)\n",
"    logging.info('GPU memory free: {}, memory used: {}'.format(info.free, info.used))\n",
"    return info.used\n",
"\n",
"def get_gpu_memory_total():\n",
"    handle = pynvml.nvmlDeviceGetHandleByIndex(0)\n",
"    info = pynvml.nvmlDeviceGetMemoryInfo(handle)\n",
"    info.total = round(info.total / 1024**2)\n",
"    return info.total\n",
"\n",
"setup_default_logging()\n",
"\n",
"print('PyTorch version:', torch.__version__)\n",
"if torch.cuda.is_available():\n",
"    print('CUDA available')\n",
"    device = 'cuda'\n",
"else:\n",
"    print('CUDA is not available')\n",
"    device = 'cpu'\n",
"\n",
"BATCH_SIZE = 128\n",
"HAS_T4 = False  # runner() reads this flag, so define it even when running on CPU\n",
"if device == 'cuda':\n",
"    pynvml.nvmlInit()\n",
"    log_gpu_memory()\n",
"    total_gpu_mem = get_gpu_memory_total()\n",
"    if total_gpu_mem > 12300:\n",
"        HAS_T4 = True\n",
"        logging.info('Running on a T4 GPU or other with > 12GB memory, setting batch size to {}'.format(BATCH_SIZE))\n",
"    else:\n",
"        BATCH_SIZE = 64\n",
"        logging.info('Running on a K80 GPU or other with < 12GB memory, batch size set to {}'.format(BATCH_SIZE))\n"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"GPU memory free: 15069, memory used: 11\n",
"Running on a T4 GPU or other with > 12GB memory, setting batch size to 128\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"PyTorch version: 1.1.0\n",
"CUDA available\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aZPg2V_Ft6mk",
"colab_type": "text"
},
"source": [
"# The Dataset\n",
"\n",
"ImageNet-V2 (https://github.com/modestyachts/ImageNetV2) is a collection of three new ImageNet-like test sets, collected roughly a decade after the original ImageNet.\n",
"\n",
"Aside from being conveniently small and easy to deploy in a notebook, it's a useful test set for comparing how well models generalize beyond the original ImageNet-1k data. We're going to use the 'Matched Frequency' version of the dataset (10,000 images). You can read more about the dataset in the paper by its creators (Benjamin Recht, Rebecca Roelofs, Ludwig Schmidt, Vaishaal Shankar): [\"Do ImageNet Classifiers Generalize to ImageNet?\"](http://people.csail.mit.edu/ludwigs/papers/imagenet.pdf)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rdHLlfERO6d_",
"colab_type": "code",
"colab": {}
},
"source": [
"# Download and extract the dataset (note: despite the .tar.gz name, the file is not actually gzipped)\n",
"if not os.path.exists('./imagenetv2-matched-frequency'):\n",
"    !curl -s https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz | tar x\n",
"dataset = Dataset('./imagenetv2-matched-frequency/')\n",
"assert len(dataset) == 10000"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "EH13ZwRhqDAi",
"colab_type": "text"
},
"source": [
"Let's take a look at some random images in the dataset...\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "QhmlwJO7VxlC",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "93522b51-0dae-41a2-96ab-b25a282ff543"
},
"source": [
"from torchvision.utils import make_grid\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def show_img(ax, img):\n",
"    # convert a CHW tensor to the HWC layout matplotlib expects\n",
"    npimg = img.numpy()\n",
"    ax.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='bicubic')\n",
"\n",
"fig = plt.figure(figsize=(8, 16), dpi=100)\n",
"ax = fig.add_subplot(111)\n",
"num_images = 4*8\n",
"images = []\n",
"dataset.transform = transforms.Compose([\n",
"    transforms.Resize(320),\n",
"    transforms.CenterCrop(320),\n",
"    transforms.ToTensor()])\n",
"for i in np.random.permutation(np.arange(len(dataset)))[:num_images]:\n",
"    images.append(dataset[i][0])\n",
"\n",
"grid_img = make_grid(images, nrow=4, padding=10, normalize=True, scale_each=True)\n",
"show_img(ax, grid_img)\n"
],
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 800x1600 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "j7DIVWU2Pcoz",
"colab_type": "code",
"colab": {}
},
"source": [
"# a basic validation routine and runner that configures each model and loader\n",
"from timm.models import TestTimePoolHead\n",
"\n",
"def validate(model, loader, criterion=None, device='cuda'):\n",
"    # metrics\n",
"    batch_time = AverageMeter()\n",
"    losses = AverageMeter()\n",
"    top1 = AverageMeter()\n",
"    top5 = AverageMeter()\n",
"\n",
"    # for collecting per sample prediction/loss details\n",
"    losses_val = []\n",
"    top5_idx = []\n",
"    top5_val = []\n",
"\n",
"    end = time.time()\n",
"    with torch.no_grad():\n",
"        for i, (input, target) in enumerate(loader):\n",
"            target = target.to(device)\n",
"            input = input.to(device)\n",
"            output = model(input)\n",
"\n",
"            if criterion is not None:\n",
"                loss = criterion(output, target)\n",
"                if not loss.size():\n",
"                    # scalar (already reduced) loss\n",
"                    losses.update(loss.item(), input.size(0))\n",
"                else:\n",
"                    # only bother collecting top5 if we're also collecting per-example loss\n",
"                    output = output.softmax(1)\n",
"                    top5v, top5i = output.topk(5, 1, True, True)\n",
"                    top5_val.append(top5v.cpu().numpy())\n",
"                    top5_idx.append(top5i.cpu().numpy())\n",
"                    losses_val.append(loss.cpu().numpy())\n",
"                    losses.update(loss.mean().item(), input.size(0))\n",
"\n",
"            prec1, prec5 = timm.utils.accuracy(output, target, topk=(1, 5))\n",
"            top1.update(prec1.item(), input.size(0))\n",
"            top5.update(prec5.item(), input.size(0))\n",
"\n",
"            batch_time.update(time.time() - end)\n",
"            end = time.time()\n",
"\n",
"            if i % 20 == 0:\n",
"                print('Test: [{0}/{1}]\\t'\n",
"                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \\t'\n",
"                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\\t'\n",
"                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(\n",
"                    i, len(loader), batch_time=batch_time,\n",
"                    rate_avg=input.size(0) / batch_time.avg,\n",
"                    top1=top1, top5=top5))\n",
"\n",
"    results = OrderedDict(\n",
"        top1=top1.avg, top1_err=100 - top1.avg,\n",
"        top5=top5.avg, top5_err=100 - top5.avg,\n",
"    )\n",
"    if criterion is not None:\n",
"        results['loss'] = losses.avg\n",
"    if len(top5_idx):\n",
"        results['top5_val'] = np.concatenate(top5_val, axis=0)\n",
"        results['top5_idx'] = np.concatenate(top5_idx, axis=0)\n",
"    if len(losses_val):\n",
"        results['losses_val'] = np.concatenate(losses_val, axis=0)\n",
"    print(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(\n",
"        results['top1'], results['top1_err'], results['top5'], results['top5_err']))\n",
"    return results\n",
"\n",
"\n",
"def runner(model_args, dataset, device='cuda', collect_loss=False):\n",
"    model_name = model_args['model']\n",
"    model = timm.create_model(model_name, pretrained=True)\n",
"    ttp = False\n",
"    if 'ttp' in model_args and model_args['ttp']:\n",
"        ttp = True\n",
"        logging.info('Applying test time pooling to model')\n",
"        model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])\n",
"    model = model.to(device)\n",
"    model.eval()\n",
"    if HAS_T4:\n",
"        model = model.half()\n",
"\n",
"    data_config = timm.data.resolve_data_config(model_args, model=model, verbose=True)\n",
"\n",
"    loader = timm.data.create_loader(\n",
"        dataset,\n",
"        input_size=data_config['input_size'],\n",
"        batch_size=BATCH_SIZE,\n",
"        use_prefetcher=True,\n",
"        # always bicubic, even where the model's default config (e.g. the ig_resnext WSL models) says bilinear\n",
"        interpolation='bicubic',\n",
"        mean=data_config['mean'],\n",
"        std=data_config['std'],\n",
"        fp16=HAS_T4,\n",
"        crop_pct=1.0 if ttp else data_config['crop_pct'],\n",
"        num_workers=2)\n",
"\n",
"    criterion = None\n",
"    if collect_loss:\n",
"        criterion = torch.nn.CrossEntropyLoss(reduction='none').to(device)\n",
"    results = validate(model, loader, criterion, device)\n",
"\n",
"    # cleanup checkpoint cache to avoid running out of disk space\n",
"    shutil.rmtree(os.path.join(os.environ['HOME'], '.cache', 'torch', 'checkpoints'), True)\n",
"\n",
"    # add some non-metric values for charting / comparisons\n",
"    results['model'] = model_name\n",
"    results['img_size'] = data_config['input_size'][-1]\n",
"\n",
"    # create key to identify model in charts\n",
"    key = [model_name, str(data_config['input_size'][-1])]\n",
"    if ttp:\n",
"        key += ['ttp']\n",
"    key = '-'.join(key)\n",
"    return key, results\n"
],
"execution_count": 0,
"outputs": []
},
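{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The Prec@1 / Prec@5 numbers printed by `validate` are top-k accuracies: the percentage of images whose true label is among the model's k highest-scoring classes. The cell below is a minimal NumPy sketch of that metric on made-up scores. It's an illustration rather than timm's implementation; we assume (without checking here) that it matches what `timm.utils.accuracy` reports, and the helper name `topk_accuracy` is hypothetical."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"def topk_accuracy(scores, targets, topk=(1, 5)):\n",
"    # hypothetical helper: % of rows whose target index is among the k largest scores\n",
"    order = np.argsort(-scores, axis=1)  # class indices per row, highest score first\n",
"    results = []\n",
"    for k in topk:\n",
"        hits = (order[:, :k] == targets[:, None]).any(axis=1)\n",
"        results.append(100.0 * hits.mean())\n",
"    return results\n",
"\n",
"# toy example: 3 samples, 6 classes (made-up scores, for illustration only)\n",
"scores = np.array([\n",
"    [0.10, 0.50, 0.20, 0.05, 0.10, 0.05],  # target 1 is ranked 1st\n",
"    [0.30, 0.25, 0.20, 0.15, 0.06, 0.04],  # target 3 is ranked 4th\n",
"    [0.40, 0.30, 0.10, 0.10, 0.07, 0.03],  # target 5 is ranked 6th\n",
"])\n",
"targets = np.array([1, 3, 5])\n",
"top1_acc, top5_acc = topk_accuracy(scores, targets)\n",
"print('top-1: {:.2f}%, top-5: {:.2f}%'.format(top1_acc, top5_acc))\n",
"# expected: top-1: 33.33%, top-5: 66.67%"
],
"execution_count": 0,
"outputs": []
},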
{
"cell_type": "code",
"metadata": {
"id": "wyOyVOehQqXL",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "1d928a20-8838-4a81-b8b9-d213de5ea435"
},
"source": [
"models = [\n",
"    dict(model='mobilenetv3_100'),\n",
"    dict(model='dpn68b'),\n",
"    dict(model='gluon_resnet50_v1d'),\n",
"    dict(model='efficientnet_b2'),\n",
"    dict(model='gluon_seresnext50_32x4d'),\n",
"    dict(model='dpn92'),\n",
"    dict(model='gluon_seresnext101_32x4d'),\n",
"    dict(model='inception_resnet_v2'),\n",
"    dict(model='pnasnet5large'),\n",
"    dict(model='tf_efficientnet_b5'),\n",
"    dict(model='ig_resnext101_32x8d'),\n",
"    dict(model='ig_resnext101_32x16d'),\n",
"    dict(model='ig_resnext101_32x32d'),\n",
"    dict(model='ig_resnext101_32x48d'),\n",
"]\n",
"\n",
"results = OrderedDict()\n",
"for ma in models:\n",
"    mk, mr = runner(ma, dataset, device)\n",
"    results[mk] = mr\n",
"\n",
"results_df = pd.DataFrame.from_dict(results, orient='index')\n",
"results_df.to_csv('./cached-results.csv')"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth\" to /root/.cache/torch/checkpoints/mobilenetv3_100-35495452.pth\n",
"100%|██████████| 22064048/22064048 [00:00<00:00, 51730902.08it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 2.193 (2.193, 58.355/s) \tPrec@1 71.094 (71.094)\tPrec@5 91.406 (91.406)\n",
"Test: [20/79]\tTime 0.086 (0.750, 170.707/s) \tPrec@1 59.375 (67.857)\tPrec@5 83.594 (87.537)\n",
"Test: [40/79]\tTime 0.087 (0.728, 175.912/s) \tPrec@1 51.562 (67.264)\tPrec@5 78.125 (87.043)\n",
"Test: [60/79]\tTime 0.088 (0.717, 178.619/s) \tPrec@1 55.469 (64.460)\tPrec@5 77.344 (85.131)\n",
" * Prec@1 63.220 (36.780) Prec@5 84.500 (15.500)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth\" to /root/.cache/torch/checkpoints/dpn68b_extra-84854c156.pth\n",
"100%|██████████| 50765517/50765517 [00:00<00:00, 67204620.18it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.48627450980392156, 0.4588235294117647, 0.40784313725490196)\n",
"\tstd: (0.23482446870963955, 0.23482446870963955, 0.23482446870963955)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 4.517 (4.517, 28.334/s) \tPrec@1 76.562 (76.562)\tPrec@5 95.312 (95.312)\n",
"Test: [20/79]\tTime 0.353 (0.806, 158.907/s) \tPrec@1 55.469 (70.126)\tPrec@5 86.719 (89.137)\n",
"Test: [40/79]\tTime 0.449 (0.771, 165.921/s) \tPrec@1 58.594 (69.531)\tPrec@5 77.344 (88.415)\n",
"Test: [60/79]\tTime 1.137 (0.759, 168.550/s) \tPrec@1 60.156 (66.829)\tPrec@5 78.906 (86.578)\n",
" * Prec@1 65.650 (34.350) Prec@5 85.930 (14.070)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth\" to /root/.cache/torch/checkpoints/gluon_resnet50_v1d-818a1b1b.pth\n",
"100%|██████████| 102573346/102573346 [00:01<00:00, 65197850.65it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 4.053 (4.053, 31.584/s) \tPrec@1 79.688 (79.688)\tPrec@5 93.750 (93.750)\n",
"Test: [20/79]\tTime 0.195 (0.796, 160.803/s) \tPrec@1 67.969 (73.251)\tPrec@5 88.281 (90.216)\n",
"Test: [40/79]\tTime 0.201 (0.763, 167.796/s) \tPrec@1 60.156 (72.142)\tPrec@5 81.250 (89.520)\n",
"Test: [60/79]\tTime 0.200 (0.749, 170.872/s) \tPrec@1 62.500 (69.173)\tPrec@5 82.812 (87.795)\n",
" * Prec@1 67.920 (32.080) Prec@5 87.140 (12.860)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth\" to /root/.cache/torch/checkpoints/efficientnet_b2-cf78dc4d.pth\n",
"100%|██████████| 36788101/36788101 [00:00<00:00, 55440272.43it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 260, 260)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.89\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 3.771 (3.771, 33.946/s) \tPrec@1 78.906 (78.906)\tPrec@5 96.094 (96.094)\n",
"Test: [20/79]\tTime 0.495 (0.870, 147.210/s) \tPrec@1 67.969 (72.917)\tPrec@5 88.281 (91.071)\n",
"Test: [40/79]\tTime 0.308 (0.835, 153.252/s) \tPrec@1 58.594 (71.970)\tPrec@5 82.031 (90.473)\n",
"Test: [60/79]\tTime 0.959 (0.831, 154.056/s) \tPrec@1 64.062 (69.352)\tPrec@5 85.938 (88.909)\n",
" * Prec@1 67.780 (32.220) Prec@5 88.210 (11.790)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth\" to /root/.cache/torch/checkpoints/gluon_seresnext50_32x4d-90cf2d6e.pth\n",
"100%|██████████| 110578827/110578827 [00:01<00:00, 70807032.63it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 4.138 (4.138, 30.933/s) \tPrec@1 81.250 (81.250)\tPrec@5 94.531 (94.531)\n",
"Test: [20/79]\tTime 0.944 (0.819, 156.361/s) \tPrec@1 70.312 (74.144)\tPrec@5 88.281 (91.071)\n",
"Test: [40/79]\tTime 1.192 (0.793, 161.325/s) \tPrec@1 60.938 (72.847)\tPrec@5 82.812 (90.415)\n",
"Test: [60/79]\tTime 1.084 (0.782, 163.666/s) \tPrec@1 64.062 (69.980)\tPrec@5 84.375 (88.806)\n",
" * Prec@1 68.620 (31.380) Prec@5 88.340 (11.660)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth\" to /root/.cache/torch/checkpoints/dpn92_extra-b040e4a9b.pth\n",
"100%|██████████| 151248422/151248422 [00:01<00:00, 83488116.01it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.48627450980392156, 0.4588235294117647, 0.40784313725490196)\n",
"\tstd: (0.23482446870963955, 0.23482446870963955, 0.23482446870963955)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 7.253 (7.253, 17.648/s) \tPrec@1 77.344 (77.344)\tPrec@5 95.312 (95.312)\n",
"Test: [20/79]\tTime 0.494 (1.027, 124.688/s) \tPrec@1 66.406 (73.214)\tPrec@5 87.500 (90.662)\n",
"Test: [40/79]\tTime 0.486 (0.923, 138.660/s) \tPrec@1 56.250 (72.142)\tPrec@5 83.594 (89.863)\n",
"Test: [60/79]\tTime 0.502 (0.882, 145.078/s) \tPrec@1 63.281 (69.262)\tPrec@5 83.594 (88.089)\n",
" * Prec@1 67.960 (32.040) Prec@5 87.510 (12.490)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth\" to /root/.cache/torch/checkpoints/gluon_seresnext101_32x4d-cf52900d.pth\n",
"100%|██████████| 196505510/196505510 [00:02<00:00, 82370287.05it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 2.038 (2.038, 62.806/s) \tPrec@1 79.688 (79.688)\tPrec@5 95.312 (95.312)\n",
"Test: [20/79]\tTime 0.546 (0.909, 140.890/s) \tPrec@1 72.656 (75.521)\tPrec@5 88.281 (91.667)\n",
"Test: [40/79]\tTime 0.538 (0.867, 147.668/s) \tPrec@1 64.062 (74.409)\tPrec@5 83.594 (91.254)\n",
"Test: [60/79]\tTime 0.553 (0.845, 151.397/s) \tPrec@1 67.188 (71.760)\tPrec@5 89.062 (89.664)\n",
" * Prec@1 70.010 (29.990) Prec@5 88.920 (11.080)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth\" to /root/.cache/torch/checkpoints/inception_resnet_v2-940b1cd6.pth\n",
"100%|██████████| 223774238/223774238 [00:03<00:00, 66800834.91it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 299, 299)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.5, 0.5, 0.5)\n",
"\tstd: (0.5, 0.5, 0.5)\n",
"\tcrop_pct: 0.8975\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 6.944 (6.944, 18.434/s) \tPrec@1 77.344 (77.344)\tPrec@5 94.531 (94.531)\n",
"Test: [20/79]\tTime 1.212 (1.125, 113.791/s) \tPrec@1 69.531 (74.479)\tPrec@5 90.625 (91.704)\n",
"Test: [40/79]\tTime 1.213 (1.014, 126.269/s) \tPrec@1 64.062 (73.857)\tPrec@5 85.156 (90.892)\n",
"Test: [60/79]\tTime 0.950 (0.978, 130.937/s) \tPrec@1 71.094 (71.593)\tPrec@5 85.156 (89.267)\n",
" * Prec@1 70.100 (29.900) Prec@5 88.700 (11.300)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth\" to /root/.cache/torch/checkpoints/pnasnet5large-bf079911.pth\n",
"100%|██████████| 345153926/345153926 [00:04<00:00, 69633749.17it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 331, 331)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.5, 0.5, 0.5)\n",
"\tstd: (0.5, 0.5, 0.5)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 10.254 (10.254, 12.483/s) \tPrec@1 82.812 (82.812)\tPrec@5 97.656 (97.656)\n",
"Test: [20/79]\tTime 2.842 (3.130, 40.889/s) \tPrec@1 71.094 (77.939)\tPrec@5 90.625 (93.341)\n",
"Test: [40/79]\tTime 2.866 (2.991, 42.795/s) \tPrec@1 67.969 (76.467)\tPrec@5 87.500 (92.397)\n",
"Test: [60/79]\tTime 2.843 (2.944, 43.477/s) \tPrec@1 73.438 (74.027)\tPrec@5 88.281 (90.779)\n",
" * Prec@1 72.410 (27.590) Prec@5 90.250 (9.750)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth\" to /root/.cache/torch/checkpoints/tf_efficientnet_b5-c6949ce9.pth\n",
"100%|██████████| 122398414/122398414 [00:02<00:00, 61095444.04it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 456, 456)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.934\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 11.010 (11.010, 11.626/s) \tPrec@1 81.250 (81.250)\tPrec@5 96.875 (96.875)\n",
"Test: [20/79]\tTime 2.901 (3.309, 38.677/s) \tPrec@1 70.312 (77.418)\tPrec@5 92.188 (93.415)\n",
"Test: [40/79]\tTime 2.892 (3.107, 41.197/s) \tPrec@1 62.500 (76.239)\tPrec@5 89.844 (92.950)\n",
"Test: [60/79]\tTime 2.908 (3.041, 42.085/s) \tPrec@1 75.000 (73.770)\tPrec@5 88.281 (91.624)\n",
" * Prec@1 72.550 (27.450) Prec@5 91.100 (8.900)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth\" to /root/.cache/torch/checkpoints/ig_resnext101_32x8-c38310e5.pth\n",
"100%|██████████| 356056638/356056638 [00:09<00:00, 38784641.11it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bilinear\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 5.765 (5.765, 22.204/s) \tPrec@1 80.469 (80.469)\tPrec@5 96.875 (96.875)\n",
"Test: [20/79]\tTime 0.855 (1.085, 117.918/s) \tPrec@1 75.781 (78.832)\tPrec@5 93.750 (94.271)\n",
"Test: [40/79]\tTime 0.853 (0.995, 128.593/s) \tPrec@1 66.406 (77.896)\tPrec@5 88.281 (93.807)\n",
"Test: [60/79]\tTime 0.833 (0.961, 133.263/s) \tPrec@1 74.219 (75.000)\tPrec@5 90.625 (92.623)\n",
" * Prec@1 73.780 (26.220) Prec@5 92.260 (7.740)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth\" to /root/.cache/torch/checkpoints/ig_resnext101_32x16-c6f796b0.pth\n",
"100%|██████████| 777518664/777518664 [00:15<00:00, 50031408.07it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bilinear\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 11.569 (11.569, 11.064/s) \tPrec@1 84.375 (84.375)\tPrec@5 99.219 (99.219)\n",
"Test: [20/79]\tTime 1.649 (2.129, 60.119/s) \tPrec@1 76.562 (80.990)\tPrec@5 95.312 (95.164)\n",
"Test: [40/79]\tTime 1.620 (1.884, 67.948/s) \tPrec@1 67.969 (79.630)\tPrec@5 89.844 (94.722)\n",
"Test: [60/79]\tTime 1.637 (1.798, 71.191/s) \tPrec@1 75.781 (77.267)\tPrec@5 90.625 (93.545)\n",
" * Prec@1 76.020 (23.980) Prec@5 93.070 (6.930)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth\" to /root/.cache/torch/checkpoints/ig_resnext101_32x32-e4b90b00.pth\n",
"100%|██████████| 1876573776/1876573776 [01:36<00:00, 19485230.33it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bilinear\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 19.815 (19.815, 6.460/s) \tPrec@1 86.719 (86.719)\tPrec@5 99.219 (99.219)\n",
"Test: [20/79]\tTime 3.245 (3.981, 32.154/s) \tPrec@1 77.344 (81.436)\tPrec@5 94.531 (95.238)\n",
"Test: [40/79]\tTime 3.405 (3.668, 34.901/s) \tPrec@1 68.750 (80.526)\tPrec@5 89.844 (94.684)\n",
"Test: [60/79]\tTime 3.437 (3.592, 35.638/s) \tPrec@1 79.688 (78.279)\tPrec@5 91.406 (93.699)\n",
" * Prec@1 77.020 (22.980) Prec@5 93.370 (6.630)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth\" to /root/.cache/torch/checkpoints/ig_resnext101_32x48-3e41cc8a.pth\n",
"100%|██████████| 3317136976/3317136976 [01:15<00:00, 43847655.19it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bilinear\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 34.840 (34.840, 3.674/s) \tPrec@1 88.281 (88.281)\tPrec@5 100.000 (100.000)\n",
"Test: [20/79]\tTime 5.808 (7.029, 18.209/s) \tPrec@1 78.906 (81.696)\tPrec@5 95.312 (95.722)\n",
"Test: [40/79]\tTime 5.890 (6.465, 19.800/s) \tPrec@1 67.969 (80.736)\tPrec@5 89.062 (94.989)\n",
"Test: [60/79]\tTime 5.872 (6.274, 20.401/s) \tPrec@1 75.000 (78.548)\tPrec@5 92.188 (93.942)\n",
" * Prec@1 77.280 (22.720) Prec@5 93.610 (6.390)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xA2_6Jl9PlU5",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['figure.figsize'] = [16, 10]\n",
"\n",
"names_all = list(results.keys())\n",
"top1_all = np.array([results[m]['top1'] for m in names_all])\n",
"top1_sort_ix = np.argsort(top1_all)\n",
"top1_sorted = top1_all[top1_sort_ix]\n",
"top1_names_sorted = np.array(names_all)[top1_sort_ix]\n",
"\n",
"top5_all = np.array([results[m]['top5'] for m in names_all])\n",
"top5_sort_ix = np.argsort(top5_all)\n",
"top5_sorted = top5_all[top5_sort_ix]\n",
"top5_names_sorted = np.array(names_all)[top5_sort_ix]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "elJkzmkHo2UT",
"colab_type": "text"
},
"source": [
"# Results\n",
"\n",
"We'll walk through the results in a few charts and text dumps:\n",
"\n",
"1. Top-1 Accuracy by Model\n",
"2. Top-1 Accuracy Difference Between ImageNet-1k and ImageNet-V2\n",
"3. Top-5 Accuracy Difference Between ImageNet-1k and ImageNet-V2\n",
"4. A Text Comparison of Absolute and Relative Differences"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7iKh048WpWmh",
"colab_type": "text"
},
"source": [
"# Top-1 Accuracy by Model\n",
"\n",
"The Instagram pretrained ResNeXts push into the mid-to-high 70s on Top-1 (73.78 to 77.28), which is great for this test set. If you're familiar with typical ImageNet-1k validation scores, you'll notice that all of these are quite a bit lower; we'll analyse the differences in the next two charts."
]
},
{
"cell_type": "code",
"metadata": {
"id": "EoyBP9GIpV7P",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 876
},
"outputId": "cb8c7272-9d65-416a-c20f-38110637895c"
},
"source": [
"import math\n",
"fig = plt.figure()\n",
"ax1 = fig.add_subplot(111)\n",
"ax1.barh(top1_names_sorted, top1_sorted, color='lightcoral')\n",
"\n",
"ax1.set_title('Top-1 by Model')\n",
"ax1.set_xlabel('Top-1 Accuracy (%)')\n",
"ax1.set_yticklabels(top1_names_sorted)\n",
"ax1.autoscale(True, axis='both')\n",
"\n",
"acc_min = top1_sorted[0]\n",
"acc_max = top1_sorted[-1]\n",
"plt.xlim([math.ceil(acc_min - .3*(acc_max - acc_min)), math.ceil(acc_max)])\n",
"\n",
"plt.vlines(plt.xticks()[0], *plt.ylim(), color='0.5', alpha=0.2, linestyle='--')\n",
"plt.show()\n",
"\n",
"print('Results by top-1 accuracy:')\n",
"results_by_top1 = list(sorted(results.keys(), key=lambda x: results[x]['top1'], reverse=True))\n",
"for m in results_by_top1:\n",
" print(' Model: {:30} Top-1 {:4.2f}, Top-5 {:4.2f}'.format(m, results[m]['top1'], results[m]['top5']))"
],
"execution_count": 9,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABDgAAAJcCAYAAAAcpRT5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzs3Xu4XVV97//3JxAgEkjCRYR6SX94\nQSsQNVUr2lIqitp651i1VlpQPD324qXVqgek9JzjrZ5qK1q03ilVKrRWrUo5eKCASoKAgFiOl1oU\nRAMJhIQAyff3xxo7LrY72TsLyBrLvF/Ps56115xzzPVd64PbZ38z5pipKiRJkiRJkibZvHEXIEmS\nJEmSdHfZ4JAkSZIkSRPPBockSZIkSZp4NjgkSZIkSdLEs8EhSZIkSZImng0OSZIkSZI08WxwSJIk\ndSDJl5P81rjr2JIkr0jyr3M89u+TvOnerkmSpGE2OCRJ0g4jydqhx6Yk64dev/gefq/dk5yZ5D+S\nVJLH35Pn38r7HtTe76Jp2w9IcmeSq7dHHZIkbW82OCRJ0g6jqhZOPYDvAb8xtO20e/rtgC8BLwRu\nuofPPZuNwL5JHjK07beAb23nOiRJ2m5scEiSJDVJFiR5T5Lrklyb5O1J5rd9RyX5f0lOSnJjku8k\nOXpL56qqdVX17qq6ENg0xxIelmRlkjVJPpVkUXvvc5K8bFqt30zytK2c6+PAbw+9fgnw0WnnODjJ\n+UlWJ7l8+HxJ7pvkc0lubrNBHjRt7COT/J8kNyX5RpJnz/EzSpJ0r7DBIUmS9BMnAYcABwOPAQ4H\n/mRo/1JgF+B+wMuAjyT5+Xvw/X8beDHwc+19/qJt/wiDGRgAJHkcsCfwxa2c66PAizPwaAZNlsuG\nzrEb8FngH4F9gT8Gzhj6PKcCNwL7Af8V+N2hsXsCZwN/C+zT6v5gkgeP9KklSboH2OCQJEn6iRcD\nJ1bVj6vqh8CfM5j5MOVO4KSqur2q/hX4V+D59+D7f6iqrq6qtcCJDC5vAfgU8KgkD2yvXwKcXlUb\nt3Siqvo28APglxk0ID467ZAnMbiM5p1VdUdVfYFB0+IFrfnxTOBNVbW+qi4Fhi/heQ5wRVWdVlUb\nq+pi4J+B543+0SVJuntscEiSJAFJwmBmxn8Mbf4PBrMppvyoqm6btv+AJA8dWqz0x3ejjP+cdu77\nJFlUVbcCZzKYkTEfeAHwsTmc76MMZl4czV0bFAAHAN+rqpr2nj/H4HvIDPVMeRDwy+3SltVJVjNo\nbuw/h5okSbpX2OCQJEkC2h/613PXtSYeCHx/6PU+bXbD8P4fVNW/Dy1Wus/dKOMB0869rqrWtNcf\nYTDD5Cjgh1X1tTmc75MMmhuXV9X10/b9oL3HsKnPez2D2R3T65nyn8AXq2rx0GNhVf3RHGqSJOle\nYYNDkiTpJ04HTkyyd5L7Am9ksFjnlPnAf0+yS5IjgCMZXD4yoyS7DjVEdpnWHJnJMW02yELgzcAn\nhvZ9CdgD+B/89OUmM6qq1QzWEfmvM+w+H5iX5I+S7JzkSOApwCfbLJV/Bk5qC68ewqC5MuUfGVwy\n84Ik89v38fgkD51LXZIk3RtscEiSJP3ECcBVwJXApcAFwNuG9n+XwToc1wMfBH6nrXWxJf8BrAf2\nBv4vsD7J/bZy/McYNFm+z2BR0NdM7WgzTD4G/AI/fbnJFlXVV6vquzNsvw34dQZriKwC3gm8YOjz\nHM9ggdEfAn8DfGho7E3AU4HfAa5jMBvkzxk0gCRJGovc9bJLSZIkzSTJUcBfV9XY7hSS5OXAf6mq\nJ4+rBkmSeuUMDkmSpAmQZHcGl5qcOu5aJEnqkQ0OSZKkziV5JnAD8P+AfxhzOZIkdclLVCRJkiRJ\n0sRzBockSZIkSZp4O4+7AGkU++yzTy1dunTc
ZUiSJEmS7gUrV678cVXtuy1jbHBoIi1dupQVK1aM\nuwxJkiRJ0r0gyX9s6xgvUZEkSZIkSRPPBockSZIkSZp4NjgkSZIkSdLEs8EhSZIkSZImng0OSZIk\nSZI08WxwSJIkSZKkiWeDQ5IkSZIkTTwbHJIkSZIkaeLZ4JAkSZIkSRPPBockSZIkSZp4NjgkSZIk\nSdLEs8EhSZIkSZImng0OSZIkSZI08WxwSJIkSZKkiWeDQ5IkSZIkTTwbHJIkSZIkaeLZ4JAkSZIk\nSRPPBockSZIkSZp4NjgkSZIkSdLEs8EhSZIkSZImng0OSZIkSZI08WxwSJIkSZKkiWeDQ5IkSZIk\nTTwbHJpId95557hLELBq1SpWrVo17jKEWfTGPPphFv0wi36YRT/Moi/m0Y9Rc9j5Hq5D2i523tn/\ndHuwZMmScZegxiz6Yh79MIt+mEU/zKIfZtEX8+jHqFn4V6Kkkc2b5ySwXphFX8yjH2bRD7Poh1n0\nwyz6Yh79GDULE9RE2rRp07hLELBu3TrWrVs37jKEWfTGPPphFv0wi36YRT/Moi/m0Y9Rc7DBoYlk\ng6MP69evZ/369eMuQ5hFb8yjH2bRD7Poh1n0wyz6Yh79GDUHGxySJEmSJGni2eCQJEmSJEkTzwaH\nJEmSJEmaeDY4JEmSJEnSxEtVjbsGaZstX768VqxYMe4ydnhTvz+SjLkSmUVfzKMfZtEPs+iHWfTD\nLPpiHv2oKubNm7eyqpZvy7id762CJP3s85d/P8yiL+bRD7Poh1n0wyz6YRZ9MY9+jJqFl6hoInmb\n2D7ceuut3HrrreMuQ5hFb8yjH2bRD7Poh1n0wyz6Yh79GDUHGxyaSDY4+nDbbbdx2223jbsMYRa9\nMY9+mEU/zKIfZtEPs+iLefRj1BxscEiSJEmSpIlng0OSJEmSJE08GxySJEmSJGni2eCQJEmSJEkT\nL1P3+pUmyfLly2vFihXjLkOSJEmSdC9IsrKqlm/LGGdwSJIkSZKkiWeDQxPJ28T2Ye3ataxdu3bc\nZQiz6I159MMs+mEW/TCLfphFX8yjH6PmYINDE8kGRx82bNjAhg0bxl2GMIvemEc/zKIfZtEPs+iH\nWfTFPPoxag42OCRJkiRJ0sTbedwFSKPYdMMNrDnppHGXscO7pT37i2T8zKIv5tEPs+iHWfTDLPph\nFn3Z0fNYdOKJ4y7hbnMGhyRJkiRJmng2OCRJkiRJ0sTbUWffaMLtNO4CBMCScRegzcyiL+bRD7Po\nh1n0wyz6YRZ9MY9+7L333iONcwaHJEmSJEmaeDY4NJG8SWwfbm0PjZ9Z9MU8+mEW/TCLfphFP8yi\nL+bRj1tuuWX2g2Zgg0MTqcZdgAC4vT00fmbRF/Poh1n0wyz6YRb9MIu+mEc/br99tCRscEiSJEmS\npIlng0OSJEmSJE08GxySJEmSJGnieZtYSSPLuAvQZmbRF/Poh1n0wyz6YRb9MIu+mEc/ktHSGPsM\njiQXjruGuyPJG6a9/mCSG5JcMW37XknOTnJNe17Sth+U5KIkG5K8dpb32i3JV5NcluTKJCcN7Tst\nyTeTXNFqmD/H+p+XpJIsb6/nJ/lIkq8n+UaSP93CuDfPVG+SByQ5N8lVrcY/HNr39iRXJ7k8yVlJ\nFk8b+8Aka2f7HgB2msuH071ucXto/MyiL+bRD7Poh1n0wyz6YRZ9MY9+7LXXXiONG3uDo6qecHfP\nkWScf+++YdrrDwNHzXDc64FzquohwDntNcCNwB8A75jDe20AjqiqQ4FlwFFJHt/2nQYcBBwMLACO\nm+1kSfYA/hD4ytDmo4Fdq+pg4DHA8UmWzqG2KXcCr6mqRwCPB/5bkke0fWcDj6yqQ4B/B6Y3T94J\n/Ms2vJckSZIkSUAHDY4ka9vzvCSntH/hPzvJ55I8fyvjvpvkrUkuAY5OcmCSzydZmeT8JAe1445u\nsxouS3Je
23ZMkjPb8dckedvQeZ/SZlRckuSMJAuTLGqzIx7Wjjk9ycuSvAVYkOTSJKcBVNV5DJoW\n0z0L+Ej7+SPAs
"text/plain": [
"<Figure size 1152x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Results by top-1 accuracy:\n",
" Model: ig_resnext101_32x48d-224 Top-1 77.28, Top-5 93.61\n",
" Model: ig_resnext101_32x32d-224 Top-1 77.02, Top-5 93.37\n",
" Model: ig_resnext101_32x16d-224 Top-1 76.02, Top-5 93.07\n",
" Model: ig_resnext101_32x8d-224 Top-1 73.78, Top-5 92.26\n",
" Model: tf_efficientnet_b5-456 Top-1 72.55, Top-5 91.10\n",
" Model: pnasnet5large-331 Top-1 72.41, Top-5 90.25\n",
" Model: inception_resnet_v2-299 Top-1 70.10, Top-5 88.70\n",
" Model: gluon_seresnext101_32x4d-224 Top-1 70.01, Top-5 88.92\n",
" Model: gluon_seresnext50_32x4d-224 Top-1 68.62, Top-5 88.34\n",
" Model: dpn92-224 Top-1 67.96, Top-5 87.51\n",
" Model: gluon_resnet50_v1d-224 Top-1 67.92, Top-5 87.14\n",
" Model: efficientnet_b2-260 Top-1 67.78, Top-5 88.21\n",
" Model: dpn68b-224 Top-1 65.65, Top-5 85.93\n",
" Model: mobilenetv3_100-224 Top-1 63.22, Top-5 84.50\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "NT8YgacDQ_Bb",
"colab_type": "code",
"colab": {}
},
"source": [
"import pandas as pd\n",
"!wget -q https://raw.githubusercontent.com/rwightman/pytorch-image-models/master/results/results-all.csv\n",
"original_df = pd.read_csv('./results-all.csv', index_col=0)\n",
"original_results = original_df.to_dict(orient='index')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "vXpJrzVnRAc4",
"colab_type": "code",
"colab": {}
},
"source": [
"# helper methods for dumbbell plot\n",
"import matplotlib.lines as mlines\n",
"\n",
"def label_line_horiz(ax, line, label, color='0.5', fs=14, halign='center'):\n",
"    # annotate the midpoint of a horizontal segment with a boxed text label\n",
"    xdata, ydata = line.get_data()\n",
"    x1, x2 = xdata\n",
"    xx = 0.5 * (x1 + x2)\n",
"    text = ax.annotate(\n",
"        label, xy=(xx, ydata[0]), xytext=(0, 1), textcoords='offset points',\n",
"        size=fs, color=color, zorder=3,\n",
"        bbox=dict(boxstyle=\"round\", fc=\"w\", color='0.5'),\n",
"        horizontalalignment=halign,\n",
"        verticalalignment='center')\n",
"    return text\n",
"\n",
"def draw_line_horiz(ax, p1, p2, label, color='black'):\n",
"    # draw a segment from p1 to p2 behind the markers, then label it\n",
"    l = mlines.Line2D(*zip(p1, p2), color=color, zorder=0)\n",
"    ax.add_line(l)\n",
"    label_line_horiz(ax, l, label)\n",
"    return l\n",
"\n",
"def label_line_vert(ax, line, label, color='0.5', fs=14, halign='center'):\n",
"    # annotate the midpoint of a vertical segment with a boxed text label\n",
"    xdata, ydata = line.get_data()\n",
"    y1, y2 = ydata\n",
"    yy = 0.5 * (y1 + y2)\n",
"    text = ax.annotate(\n",
"        label, xy=(xdata[0], yy), xytext=(0, 0), textcoords='offset points',\n",
"        size=fs, color=color, zorder=3,\n",
"        bbox=dict(boxstyle=\"round\", fc=\"w\", color='0.5'),\n",
"        horizontalalignment=halign,\n",
"        verticalalignment='center')\n",
"    return text\n",
"\n",
"def draw_line_vert(ax, p1, p2, label, color='black'):\n",
"    # draw a segment from p1 to p2 behind the markers, then label it\n",
"    l = mlines.Line2D(*zip(p1, p2), color=color, zorder=0)\n",
"    ax.add_line(l)\n",
"    label_line_vert(ax, l, label)\n",
"    return l\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "tkADTg3_plVy",
"colab_type": "text"
},
"source": [
"# Top-1 Accuracy Difference Between ImageNet-1k and ImageNet-V2\n",
"\n",
"And here we are, the focal point. How does each model's ImageNet-V2 accuracy compare with its original ImageNet-1k score? \n",
"\n",
"The general trend: as model capacity increases, scores on both sets increase and the gap (generally) narrows. This matches the results of the [original ImageNet-V2 paper](https://arxiv.org/abs/1902.10811).\n",
"\n",
"Most notably, though, the WSL Instagram ResNeXt101 models are the only ones with absolute Top-1 gaps of less than 10 percentage points (between 8.07 and 8.91). I've tested quite a few models on this dataset, more than are in this notebook, and this is the first time I've run into any models with absolute performance this high and performance gaps this low. Impressive. I hope to explore what this means for transfer learning and adaptation of these models to other tasks."
]
},
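{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the two gaps printed below the chart are computed from a model's ImageNet-1k top-1 accuracy $a_{1k}$ and its ImageNet-V2 top-1 accuracy $a_{V2}$ (both in %), matching the `gaps` expressions in the code cell. As a purely illustrative example with round numbers (not one of the models tested here), a model scoring 85.0 on ImageNet-1k and 77.0 on ImageNet-V2 has:\n",
"\n",
"$$\\Delta_{abs} = a_{V2} - a_{1k} = 77.0 - 85.0 = -8.0 \\text{ percentage points}$$\n",
"\n",
"$$\\Delta_{rel} = 100 \\cdot \\frac{a_{V2} - a_{1k}}{a_{1k}} = 100 \\cdot \\frac{-8.0}{85.0} \\approx -9.41\\%$$\n",
"\n",
"Because the relative gap divides by $a_{1k}$, the same absolute drop gives a slightly smaller relative gap for a model with a higher ImageNet-1k score."
]
},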
{
"cell_type": "code",
"metadata": {
"id": "PXuGz5uEuNJc",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "a4b3331d-e501-4bc6-e4d0-c5a9f880b064"
},
"source": [
"fig = plt.figure()\n",
"ax1 = fig.add_subplot(111)\n",
"\n",
"# draw the ImageNet-V2 dots, we're sorted on this\n",
"ax1.scatter(x=top1_names_sorted, y=top1_sorted, s=64, c='lightcoral',marker=\"o\", label='ImageNet-V2 Matched-Freq')\n",
"\n",
"# draw the original ImageNet-1k validation dots\n",
"orig_top1 = [original_results[results[n]['model']]['top1'] for n in top1_names_sorted]\n",
"ax1.scatter(x=top1_names_sorted, y=orig_top1, s=64, c='steelblue', marker=\"o\", label='ImageNet-1K')\n",
"\n",
"for n, vo, vn in zip(top1_names_sorted, orig_top1, top1_sorted):\n",
" draw_line_vert(ax1, (n, vo), (n, vn),\n",
" str(round(vo - vn, 2)), 'skyblue')\n",
"\n",
"ax1.set_title('Top-1 Difference')\n",
"ax1.set_ylabel('Top-1 Accuracy (%)')\n",
"ax1.set_xlabel('Model')\n",
"yl, yh = ax1.get_ylim()\n",
"yl = 5 * ((yl + 1) // 5 + 1) \n",
"yh = 5 * (yh // 5 + 1)\n",
"for y in plt.yticks()[0][1:-1]:\n",
" ax1.axhline(y, 0.02, 0.98, c='0.5', alpha=0.2, linestyle='-.')\n",
"ax1.set_xticklabels(top1_names_sorted, rotation=-30, ha='left')\n",
"ax1.legend(loc='upper left')\n",
"plt.show()\n",
"\n",
"print('Results by absolute accuracy gap between ImageNet-V2 and original ImageNet top-1:')\n",
"gaps = {x: (results[x]['top1'] - original_results[results[x]['model']]['top1']) for x in results.keys()}\n",
"sorted_keys = list(sorted(results.keys(), key=lambda x: gaps[x], reverse=True))\n",
"for m in sorted_keys:\n",
" print(' Model: {:30} {:4.2f}%'.format(m, gaps[m]))\n",
"print()\n",
"\n",
"print('Results by relative accuracy gap between ImageNet-V2 and original ImageNet top-1:')\n",
"gaps = {x: 100 * (results[x]['top1'] - original_results[results[x]['model']]['top1']) / original_results[results[x]['model']]['top1'] for x in results.keys()}\n",
"sorted_keys = list(sorted(results.keys(), key=lambda x: gaps[x], reverse=True))\n",
"for m in sorted_keys:\n",
" print(' Model: {:30} {:4.2f}%'.format(m, gaps[m]))\n",
"print()"
],
"execution_count": 12,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA/sAAAKsCAYAAACzhK3BAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzs3Xt0VPW9///nZzIzIZOEJBDulyAB\nIncIFyERoQICCkqrAa0HVMRLPdbL6fF8e9b3a7X91uNZrW2P9fhrvRyXN1owWLHwBUUqaEmgQCCA\nARQDJiCBhiTkQi6TyezfH0mmhFwIIZMJk9djLZfsvT+fz37vZCeZ996fi7EsCxEREREREREJHrZA\nByAiIiIiIiIi7UvJvoiIiIiIiEiQUbIvIiIiIiIiEmSU7IuIiIiIiIgEGSX7IiIiIiIiIkFGyb6I\niIiIiIhIkFGyLyIi0kkYY35qjPnvC7aXGmO+NcaUGWNGGmPGGGMO1G0/GMhYRUREpHMzlmUFOgYR\nEZF2YYwpu2DTBVQBNXXbD1mWtaodzxUOvANMAgYD0y3L2tlC+Z3ABKAasIAvgTXAi5ZlVTdT51tg\nhWVZH9dtrwJyLcv69/a6DhEREQlOerMvIiJBw7KsiPr/gFxg0QX72i3Rrz8dsA24CyhqZZ2VlmVF\nAv2BfwfuBT5sqqAxxlFXLuuC3XEXbbeaMcbelnoiIiJydVKyLyIiXYYxJswY87IxJs8Yc9IY88u6\npBpjzHxjzNd1XekLjTHHjTEpzbVlWVa5ZVm/tSwrHfBeThyWZZVZlrUFuA2YbYyZUxfDfxpjXjfG\nRPGPBwhfGmOyjDHpwHTg9bpu/IPrrue/jDEnjDGnjTEvGWNCL7qep40xZ4Df1e3/bt1QgHPGmL8a\nY0Zd8PU5bYx50hjzhTGm2BizyhjjvOB4Sl3dUmPMUWPM7Lr9PYwxb9fVP2GMecYYo88YIiIiAaQ/\nxCIi0pX8FBgHjKW2+/0s4N8uOD4EcAJ9gQeAt4wx1/grGMuysoH9wIyL9hcDsXWbCZZljbYsKwnY\nTW3vgAjLsnKBXwMD664nARgB/Pii63EAg4DHjDHTgP8PuA/oSe0whHUXvfW/A5gNDAOuA74PYIy5\nAXgVeByIqitzoq7OKqAYGApMBRYDy9r6dREREZErp2RfRES6kruBZyzLOmtZ1hng5zRMSj3ATy3L\ncte9ed9CbfLrT6eAHpdbqS5Bvx943LKsc3UPCP4TuPOCYlXA/627ngrgIeC/LcvKsCyrxrKsV4FQ\nah981PuNZVlnLMvKBzZSO88Adef6vWVZWy3L8lqWlWtZ1lfGmDjgBuBf6no75AG/vSgOERER6WAa\nvyciIl2CMcZQ+8Y+54LdOcCAC7bzLcuqvOh4f2PMCGBv3b5Ky7JiaT8DqH27f7n6U/vWPqv20gAw\n1D6wqHf6osn/4oAlxpinLtjnpOHX4PQF/y7nHz0MBgF/bSKOOKAbkH9BHDbg61ZfiYiIiLQ7Jfsi\nItIlWJZlGWNOU5ucZtftHgx8e0GxWGNMtwsS/sHAdsuyvgIi2jsmY8xQaocVtGV2/TxqE/t4y7IK\nmilz8ZI7J4D/Z1nWr9pwvhNAfDP7y4AYS0v8iIiIdBrqxi8iIl3JH4FnjDE9jTG9gf8NvHvBcQfw\ntDHGaYy5EZgLvN9cY8aYUGNMt7pN5wX/bpExJryu/XXAtrohA5el7o39G8CLxphYU2uQMWZuC9Ve\nBX5ojJlcVz7CGHOrMcbVilO+DjxkjLnBGGOrO9cIy7KOAzuBXxhjIuuODTfGXH+51yQiIiLtR8m+\niIh0JT8BDlG7fF0mkAb84oLj31D7tvw0tYn0fZZlHWuhvRyggtrJ7j4DKowxfVso/7oxprSu/V9S\nO7HdojZdSa0nqB3zv4faCfI+onZivSZZlpUG
PAa8ApwDvqJ2Ar5LvpG3LOuvwMPUTvBXDPyF2skB\noXb5wWjgCFAIrAH6tOWCREREpH0Y9bgTERGpXaqO2snrmk2WRURERK4WerMvIiIiIiIiEmSU7IuI\niIiIiIgEGXXjFxEREREREQkyerMvIiIiIiIiEmSU7IuIiIiIiIgEGXugA2iN2NhYa8iQIYEOQ0RE\nRERERPwgIyPjrGVZvQIdRzC5KpL9IUOGsGfPnkCHISIiIiIiIn5gjMkJdAzBRt34RURERERERIKM\nkn0RERERERGRIKNkX0RERERERCTIXBVj9ptSXV3NyZMnqaysDHQoIq3SrVs3Bg4ciMPhCHQoIiIi\nIiIS5K7aZP/kyZNERkYyZMgQjDGBDkekRZZlUVBQwMmTJ7nmmmsCHY6IiIiIiAS5q7Ybf2VlJT17\n9lSiL1cFYww9e/ZUTxQREREREekQV22yDyjRl6uK7lcREREREekoV3WyH2gREREBO/eQIUO4/fbb\nfdtr167l3nvvbbFOZmYmGzdubPLYT3/6U/793/+9UfmRI0dSXl7OLbfcwrXXXsvo0aP58Y9/3GQb\nb775JsYYtmzZ4tu3bt06jDGsXbu2xdjefPNNTp06dckyjz76aItlWrJt2zYWLlzY5LFZs2aRkJDA\nhAkTmDBhwiXjFRERERER6cy6RLJvud1Ubt1KyS9/SfFPf0rJL39J5datWG53oEO7IhkZGRw6dKjV\n5VtK9u+66y7WrFnTYN/q1au56667APjXf/1Xjhw5wr59+0hLS2PTpk1NtjN27FhWr17t2/7jH//I\n+PHjLxlba5J9f1u1ahWZmZlkZmZyxx13NDhmWRZerzdAkYmIiIiIiFyeoE/2Lbebstdfpyo9Hau8\nvHZfeTlV6emUvf56uyT827ZtY+bMmdx2220MHTqUH//4x6xatYqpU6cyduxYsrOzAVi/fj3XXXcd\nEydOZM6cOZw5cwaA/Px85s6dy+jRo1m5ciVxcXGcPXsWgHfffZepU6cyYcIEHnroIWpqanzn/dGP\nfsRzzz3XKJ7z58+zYsUKpk6dysSJE/nwww9xu9385Cc/Yc2aNUyYMKFRYj9ixAhiYmL429/+5tv3\n3nvvcdddd+FyufjOd74DgNPpJDExkZMnTzb5tZgxYwa7du2iurqasrIyvv76ayZMmOA7/rOf/Ywp\nU6YwZswYHnzwQSzLYu3atezZs4e7776bCRMmUFFRwe7du0lKSmL8+PFMnTqV0tJSAE6dOsX8+fMZ\nPnw4//Zv/+Zrd/PmzUyfPp3ExERSUlIoKysD4KOPPuLaa68lMTGRP/3pT638jtb65ptvSEhIYPny\n5YwZM4YTJ0606jyPPfZYsz0IREREREREOkLQJ/tVaWl4i4rA42l4wOPBW1REVVpau5xn//79/P73\nv+fw4cO88847fPXVV+zatYuVK1fy0ksvAXD99dezc+dO9u3bx5133skvfvELoLYL/Y033khWVhZ3\n3HEHubm5ABw+fJg1a9aQlpZGZmYmISEhrFq1ynfOJUuWsHfvXr7++usGsTz33HPceOON7Nq1i61b\nt/LUU09RXV3Nz372M5YuXUpmZiZLly5tdA133XWX7638zp076dGjB8OHD29Q5ty5c6xfv57Zs2c3\n+XUwxjBnzhw+/vhjPvzwQ2699dYGxx999FF2797NF198QUVFBRs2bOCOO+5g8uTJvjfrISEhLF26\nlBdffJH9+/ezZcsWwsLCgNreCWvWrOHgwYOsWbOGEydOcPbsWX7+85+zZcsW9u7dy+TJk/n1r39N\nZWUlDzzwAOvXrycjI4PTp0+3+D2sf9gwYcIECgoKADh69CiPPPIIWVlZhIeHt8t5RERERERE/O2q\nXXqvtdx79jRO9Ot5PLj37KFb3VvrKzFlyhT69esHQHx8PDfddBNQ261969atQO1ygUuXLiUvLw+3\n2+1bgm37
9u188MEHAMyfP5+YmBgA/vKXv5CRkcGUKVMAqKiooHfv3r5zhoSE8NRTT/H888+zYMEC\n3/7Nmzfz5z//m
"text/plain": [
"<Figure size 1152x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Results by absolute accuracy gap between ImageNet-Sketch and original ImageNet top-1:\n",
" Model: ig_resnext101_32x32d-224 -8.07%\n",
" Model: ig_resnext101_32x16d-224 -8.16%\n",
" Model: ig_resnext101_32x48d-224 -8.16%\n",
" Model: ig_resnext101_32x8d-224 -8.91%\n",
" Model: pnasnet5large-331 -10.33%\n",
" Model: inception_resnet_v2-299 -10.36%\n",
" Model: tf_efficientnet_b5-456 -10.63%\n",
" Model: gluon_seresnext101_32x4d-224 -10.89%\n",
" Model: gluon_resnet50_v1d-224 -11.15%\n",
" Model: gluon_seresnext50_32x4d-224 -11.29%\n",
" Model: dpn68b-224 -11.86%\n",
" Model: efficientnet_b2-260 -11.97%\n",
" Model: dpn92-224 -12.06%\n",
" Model: mobilenetv3_100-224 -12.41%\n",
"\n",
"Results by relative accuracy gap between ImageNet-Sketch and original ImageNet top-1:\n",
" Model: ig_resnext101_32x32d-224 -9.49%\n",
" Model: ig_resnext101_32x48d-224 -9.55%\n",
" Model: ig_resnext101_32x16d-224 -9.69%\n",
" Model: ig_resnext101_32x8d-224 -10.77%\n",
" Model: pnasnet5large-331 -12.48%\n",
" Model: tf_efficientnet_b5-456 -12.78%\n",
" Model: inception_resnet_v2-299 -12.88%\n",
" Model: gluon_seresnext101_32x4d-224 -13.46%\n",
" Model: gluon_resnet50_v1d-224 -14.11%\n",
" Model: gluon_seresnext50_32x4d-224 -14.13%\n",
" Model: efficientnet_b2-260 -15.01%\n",
" Model: dpn92-224 -15.07%\n",
" Model: dpn68b-224 -15.31%\n",
" Model: mobilenetv3_100-224 -16.41%\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qjHwbvYLpwLH",
"colab_type": "text"
},
"source": [
"# Top-5 Accuracy Difference Between ImageNet-1k and ImageNet-V2\n",
"\n",
"The Top-5 differences are very similar to the Top-1 differences above: the same overall trend and the same stand-out performance from the IG ResNeXts."
]
},
{
"cell_type": "code",
"metadata": {
"id": "pFFXDpl9Jclm",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "01bb350c-b625-434b-9f9c-f6703b28d0ae"
},
"source": [
"fig = plt.figure()\n",
"ax1 = fig.add_subplot(111)\n",
"\n",
"# draw the ImageNet-V2 top-5 dots, we're sorted on this\n",
"ax1.scatter(x=top5_names_sorted, y=top5_sorted, s=64, c='lightcoral',marker=\"o\", label='ImageNet-V2 Matched-Freq')\n",
"\n",
"# draw the original ImageNet-1k validation dots\n",
"orig_top5 = [original_results[results[n]['model']]['top5'] for n in top5_names_sorted]\n",
"ax1.scatter(x=top5_names_sorted, y=orig_top5, s=64, c='steelblue', marker=\"o\", label='ImageNet-1K')\n",
"\n",
"for n, vo, vn in zip(top5_names_sorted, orig_top5, top5_sorted):\n",
" draw_line_vert(ax1, (n, vo), (n, vn),\n",
" str(round(vo - vn, 2)), 'skyblue')\n",
"\n",
"ax1.set_title('Top-5 Difference')\n",
"ax1.set_ylabel('Top-5 Accuracy (%)')\n",
"ax1.set_xlabel('Model')\n",
"yl, yh = ax1.get_ylim()\n",
"yl = 5 * ((yl + 1) // 5 + 1) \n",
"yh = 5 * (yh // 5 + 1)\n",
"for y in plt.yticks()[0][2:-2]:\n",
" ax1.axhline(y, 0.02, 0.98, c='0.5', alpha=0.2, linestyle='-.')\n",
"ax1.set_xticklabels(top5_names_sorted, rotation=-30, ha='left')\n",
"ax1.legend(loc='upper left')\n",
"plt.show()\n",
"\n",
"print('Results by absolute accuracy gap between ImageNet-V2 and original ImageNet top-5:')\n",
"gaps = {x: (results[x]['top5'] - original_results[results[x]['model']]['top5']) for x in results.keys()}\n",
"sorted_keys = list(sorted(results.keys(), key=lambda x: gaps[x], reverse=True))\n",
"for m in sorted_keys:\n",
" print(' Model: {:30} {:4.2f}%'.format(m, gaps[m]))\n",
"print()\n",
"\n",
"print('Results by relative accuracy gap between ImageNet-V2 and original ImageNet top-5:')\n",
"gaps = {x: 100 * (results[x]['top5'] - original_results[results[x]['model']]['top5']) / original_results[results[x]['model']]['top5'] for x in results.keys()}\n",
"sorted_keys = list(sorted(results.keys(), key=lambda x: gaps[x], reverse=True))\n",
"for m in sorted_keys:\n",
" print(' Model: {:30} {:4.2f}%'.format(m, gaps[m]))"
],
"execution_count": 13,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA/sAAAKsCAYAAACzhK3BAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzs3Xl8lOW9///XNZlMNgIkYJA1rAZk\nCyQBSaSigKIFl0pEakVUtK2ntfa0nl+P32NdvrXt6Xo8Pd9uLocunEoDVgp1RQUlQDHBoLIUTCQh\nsggkkIRJMpnM9fsjyyEkJCFkZpKZ9/Px4AFz39d93Z873JmZz7XdxlqLiIiIiIiIiIQOR7ADEBER\nEREREZHupWRfREREREREJMQo2RcREREREREJMUr2RUREREREREKMkn0RERERERGREKNkX0RERERE\nRCTEKNkXERHpIYwxK40x/3LW64eMMZ8ZY6qMMXHGmDnGmMLG1wuCGauIiIj0bMZaG+wYREREuoUx\npuqsl7FALVDf+PrL1tpV3Xiu8cBe4MxZm5+01v7oPOWPAn0b4/ECHwG/A56zbXwYG2PigHJgsrX2\nH43bcoHfW2t/013XISIiIqHJGewAREREuou1tk/Tv40xB4EV1tqNfjxl/dnn7IRrrbVbjDH9gWuA\n/wDSgK+2UXYw4GhK9BslA7u7Eqgxxmmt9XblWBEREel9NIxfRETChjEmxhjz/4wxR4wxpcaYHxtj\nIhv3LTDGfGyMecIYU2aM+cQYk+2POKy1p6y1LwJ3AF82xoxrjOEFY8y/GWMmA7uAiMYh+68YY0qB\nIcDrTSMYjDGJxpjfG2OOGmMOGWMeM8Y4Gvd9xRjzVuP1lgPfadz+ZWPMPxqv8W/GmKGN26ONMdYY\nc3/jVIFyY8zPz/n5PWCM2WeMqTTGfNgYJ8aY4caYdcaYE8aYImPMV/zxcxMREZHOU7IvIiLh5Alg\nCjCZhh71OcC/nLV/JOACLgXuA35njBnVTn0RxphPGxPtZ4wxiRcSjLX2XeAEcOU52z9sjK/eWtvH\nWnu9tXYY8BkNowOaRhOsAk4Do4EZwM3AnWdV9TmgABgI/NQYswR4CFgEDALeB/54TlgLgGnAdOBu\nY8wcAGPMncD/ByylYTrCYqDcGBMBvAxspaExYgHwiDHmqgv5WYiIiEj3UrIvIiLh5A7gMWvtCWvt\nMeB7tEyOvcAT1lpP4/D/jTQktW05QkNCPAKYSUPy/N9diOkwcEGNBADGmGQakvl/tta6rbVHgP8E\nbj+rWJG19hlrbb21thr4CvA9a+1+a20dDY0fVxpjBp11zPettRXW2k+Ad4DUxu0rGve9bxv8w1pb\nSkNDRbS19t8bf277afg5nB2HiIiIBJjm7IuISFgwxhgaeuyLz9pcDAw96/Vxa23NOfuHGGMuA3Y2\nbqux1g601p6moWcc4LAx5kHgY2NM9Dl1dGQoUHYh19IoGYgGjjdcGtDQiP/xWWUOtXHMr40x/++s\nbV5gGA0jBACOnrXPDTSNIhgOFJ4njpHGmFNnbYugoaFEREREgkTJvoiIhAVrrW1cET+Z/01aRwCf\nnlVs4DnJ+ghgS2NvdUcL8VnANP7pFGPMlcAAYEtnjznLIaAKSGhrNf+zYjr3mIettWvbiCW6E+cb\nQ+sk/hCwz1o7ueOQRUREJFA0jF9ERMLJn4DHjDEDjDFJwP+h5Zz1SOBRY4zLGHMNMB9olRgDGGNm\nGWPGmgZJNKys/3rjcPl2GWP6GWNubjz3s9baAxd6IY3D7LcDPzLGxBtjHMaYcY0NCOfza+DfjDEp\njXEkGGNu7eQpnwW+Y4yZ2njNlxljhtHYUGGMeahxkT+nMWaKMWb6hV6TiIiIdB8l+yIiEk6+C+yh\n4fF1BUAu8KOz9h+kYVj7UeB54G5rbdF56rqM
hl7uqsa6TgHLOjh/00r6xcDDwA9omEffVUuB/sA+\nGqYCrKZh7YA2WWv/BPwX8KIxpqIx7vmdOZG19g/Az4A1QGXj3/0b5/7fAGTScF3HgV/R8UgIERER\n8SNz/pF/IiIi4cMYswD4L2vt2GDHIiIiInKx1LMvIiIiIiIiEmKU7IuIiIiIiIiEGA3jFxERERER\nEQkx6tkXERERERERCTFK9kVERERERERCjDPYAXTGwIED7ciRI4MdhoiIiIiIiPhBfn7+CWvtJcGO\nI5T0imR/5MiR5OXlBTsMERERERER8QNjTHGwYwg1GsYvIiIiIiIiEmKU7IuIiIiIiIiEGCX7IiIi\nIiIiIiGmV8zZb0tdXR2lpaXU1NQEOxSRTomOjmbYsGFERkYGOxQREREREQlxvTbZLy0tJT4+npEj\nR2KMCXY4Iu2y1nLy5ElKS0sZNWpUsMMREREREZEQ12uH8dfU1DBgwAAl+tIrGGMYMGCARqKIiIiI\niEhA+DXZN8Z8wxjzkTFmtzHmocZtqcaY7caYAmNMnjFmxkXU333BiviZ7lcREREREQkUvyX7xphJ\nwH3ADGAqsNAYMxb4EfCEtTYV+G7j616pT58+QTv3yJEjufXWW5tfr1mzhuXLl7d7TEFBAS+//HKb\n+5544gn+9V//tVX5CRMm4Ha7+fznP8/48eOZOHEi3/nOd9qsY+XKlRhj2LhxY/O2l156CWMMa9as\naTe2lStXcvjw4Q7LfO1rX2u3THs2bdrEwoUL29w3Z84cUlJSSE1NJTU1tcN4RUREREREejJ/9uxP\nAP5urXVba73AZuALgAX6NpbpB7Sf4XUD6/FQ8/bbVPz4x5x+4gkqfvxjat5+G+vx+PvUfpWfn8+e\nPXs6Xb69ZH/p0qWsXr26xbYXXniBpUuXAvDtb3+bffv28f7775Obm8srr7zSZj2TJ0/mhRdeaH79\npz/9ialTp3YYW2eSfX9btWoVBQUFFBQUsHjx4hb7rLX4fL4gRSYiIiIiInJh/JnsfwTMNsYMMMbE\nAjcAw4GHgB8bYw4BPwH+ta2DjTH3Nw7zzzt+/HiXg7AeD1XPPkvt1q1Yt7thm9tN7datVD37bLck\n/Js2beKqq67ipptuYvTo0XznO99h1apVzJgxg8mTJ1NYWAjA+vXrmTlzJtOmTWPevHkcO3YMgOPH\njzN//nwmTpzIihUrSE5O5sSJEwD88Y9/ZMaMGaSmpvLlL3+Z+vr65vN+61vf4qmnnmoVz5kzZ7jn\nnnuYMWMG06ZNY926dXg8Hr773e+yevVqUlNTWyX2l112GQkJCfz9739v3vbnP/+ZpUuXEhsby9VX\nXw2Ay+Vi+vTplJaWtvmzmD17Njt27KCuro6qqio+/vhjUlNTm/c/+eSTZGRkMGnSJO6//36staxZ\ns4a8vDzuuOMOUlNTqa6u5r333iMzM5OpU6cyY8YMKisrATh8+DALFixg3Lhx/Mu//Etzva+//jqz\nZs1i+vTpZGdnU1VVBcCrr77K+PHjmT59Oi+++GIn/0cbHDx4kJSUFJYtW8akSZM4dOhQp87z4IMP\nnncEgYiIiIiISCD4Ldm31u4F/h14HXgVKADqga8C37TWDge+CTx3nuN/a61Nt9amX3LJJV2OozY3\nF195OXi9LXd4vfjKy6nNze1y3WfbtWsXv/71r9m7dy9/+MMf2L9/Pzt27GDFihX84he/AODKK69k\n+/btvP/++9x+++386EcNMxieeOIJrrnmGnbv3s3ixYspKSkBYO/evaxevZrc3FwKCgqIiIhg1apV\nzee87bbb2LlzJx9//HGLWJ566imuueYaduzYwdtvv83DDz9MXV0dTz75JEuWLKGgoIAlS5a0uoal\nS5c298pv376dxMRExo0b16LMqVOnWL9+PXPnzm3z52CMYd68ebz22musW7eOG2+8scX+r33ta7z3\n3nt89NFH
VFdXs2HDBhYvXkx6enpzz3pERARLlizh6aefZteuXWzcuJGYmBigYXTC6tWr+fDDD1m9\nejWHDh3ixIkTf
"text/plain": [
"<Figure size 1152x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Results by relative accuracy gap between ImageNet-Sketch and original ImageNet top-5:\n",
" Model: ig_resnext101_32x48d-224 -3.96%\n",
" Model: ig_resnext101_32x32d-224 -4.07%\n",
" Model: ig_resnext101_32x16d-224 -4.13%\n",
" Model: ig_resnext101_32x8d-224 -4.37%\n",
" Model: tf_efficientnet_b5-456 -5.44%\n",
" Model: pnasnet5large-331 -5.79%\n",
" Model: gluon_seresnext101_32x4d-224 -6.37%\n",
" Model: gluon_seresnext50_32x4d-224 -6.48%\n",
" Model: efficientnet_b2-260 -6.50%\n",
" Model: inception_resnet_v2-299 -6.61%\n",
" Model: dpn92-224 -7.33%\n",
" Model: gluon_resnet50_v1d-224 -7.34%\n",
" Model: dpn68b-224 -7.89%\n",
" Model: mobilenetv3_100-224 -8.21%\n",
"\n",
"Results by relative accuracy gap between ImageNet-Sketch and original ImageNet top-5:\n",
" Model: ig_resnext101_32x48d-224 -4.06%\n",
" Model: ig_resnext101_32x32d-224 -4.17%\n",
" Model: ig_resnext101_32x16d-224 -4.25%\n",
" Model: ig_resnext101_32x8d-224 -4.52%\n",
" Model: tf_efficientnet_b5-456 -5.63%\n",
" Model: pnasnet5large-331 -6.03%\n",
" Model: gluon_seresnext101_32x4d-224 -6.69%\n",
" Model: gluon_seresnext50_32x4d-224 -6.83%\n",
" Model: efficientnet_b2-260 -6.86%\n",
" Model: inception_resnet_v2-299 -6.94%\n",
" Model: dpn92-224 -7.73%\n",
" Model: gluon_resnet50_v1d-224 -7.76%\n",
" Model: dpn68b-224 -8.41%\n",
" Model: mobilenetv3_100-224 -8.85%\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pEHYrp_MJXAn",
"colab_type": "text"
},
"source": [
"# Best and Worst Predictions\n",
"We're going to re-run inference on one of our better models, a ResNeXt101-32x32d pretrained on Instagram hashtags (`ig_resnext101_32x32d`). We'll collect per-example losses and top-5 predictions and then display the results."
]
},
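{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `runner(..., collect_loss=True)` helper used below isn't shown in this part of the notebook. As a rough sketch of what collecting per-example losses and top-5 predictions involves, assuming a batch of raw logits and integer labels as NumPy arrays, the hypothetical helper `per_example_loss_and_top5` (not part of timm) computes an unreduced cross-entropy loss and the top-5 softmax probabilities for each example:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def per_example_loss_and_top5(logits, targets, k=5):\n",
"    # logits: (N, C) raw model outputs; targets: (N,) ground-truth class ids\n",
"    # numerically stable log-softmax: subtract each row's max before exp\n",
"    shifted = logits - logits.max(axis=1, keepdims=True)\n",
"    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
"    # per-example cross-entropy is -log p(true class), with no averaging\n",
"    losses = -log_probs[np.arange(len(targets)), targets]\n",
"    # top-k class ids by descending probability, plus those probabilities\n",
"    probs = np.exp(log_probs)\n",
"    top_idx = np.argsort(-probs, axis=1)[:, :k]\n",
"    top_val = np.take_along_axis(probs, top_idx, axis=1)\n",
"    return losses, top_idx, top_val\n",
"\n",
"# usage on random data standing in for real model outputs\n",
"rng = np.random.default_rng(0)\n",
"logits = rng.normal(size=(8, 10))\n",
"targets = rng.integers(0, 10, size=8)\n",
"losses, top5_idx, top5_val = per_example_loss_and_top5(logits, targets)\n",
"best_idx = np.argsort(losses)[:3]         # lowest loss: highest prob on the true class\n",
"worst_idx = np.argsort(losses)[::-1][:3]  # highest loss\n",
"```\n",
"\n",
"Sorting the per-example losses in ascending order picks out the best predictions and descending order the worst, which is the same idea as the `np.argsort(mr['losses_val'])` selection used further down."
]
},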
{
"cell_type": "code",
"metadata": {
"id": "2uwWupygI8AM",
"colab_type": "code",
"colab": {}
},
"source": [
"# some code to display images in a grid and ground truth vs predictions for specified indices\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torchvision.transforms as transforms\n",
"from torchvision.utils import make_grid\n",
"from PIL import Image\n",
"\n",
"def show_img(ax, img):\n",
"    # make_grid returns a CHW tensor; imshow expects HWC\n",
"    npimg = img.numpy()\n",
"    ax.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='bicubic')\n",
"\n",
"def show_summary(indices, dataset, nrows):\n",
"    # note: make_grid's `nrow` is the number of images per row of the grid\n",
"    col_scale = len(indices) // nrows\n",
"    top5_idx = mr['top5_idx'][indices]\n",
"    top5_val = mr['top5_val'][indices]\n",
"\n",
"    images = []\n",
"    labels = []\n",
"\n",
"    # swap in a larger, un-normalized transform for display; this mutates the\n",
"    # dataset, so later evaluation with it needs the original transform restored\n",
"    dataset.transform = transforms.Compose([\n",
"        transforms.Resize(320, Image.BICUBIC),\n",
"        transforms.CenterCrop(320),\n",
"        transforms.ToTensor()])\n",
"\n",
"    for i in indices:\n",
"        img, label = dataset[i]\n",
"        images.append(img)\n",
"        labels.append(label)\n",
"    filenames = dataset.filenames(list(indices), basename=True)\n",
"\n",
"    fig = plt.figure(figsize=(10, 10 * col_scale), dpi=100)\n",
"    ax = fig.add_subplot(111)\n",
"    grid_best = make_grid(images, nrow=nrows, padding=10, normalize=True, scale_each=True)\n",
"    show_img(ax, grid_best)\n",
"    plt.show()\n",
"\n",
"    for i, l in enumerate(labels):\n",
"        # look up the synset by the ground-truth label `l`, not the loop position `i`\n",
"        print('{} ground truth = {}'.format(\n",
"            id_to_synset[l] + '/' + filenames[i], id_to_text[l]))\n",
"        print('Predicted:')\n",
"        for pi, pv in zip(top5_idx[i], top5_val[i]):\n",
"            if pv > 2e-5:\n",
"                print(' {:.3f} {}'.format(100 * pv, id_to_text[pi]))\n",
"        print()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KoNaIzU9hPSV",
"colab_type": "code",
"colab": {}
},
"source": [
"# create mappings of label id to text and synset\n",
"!wget -q https://raw.githubusercontent.com/HoldenCaulfieldRye/caffe/master/data/ilsvrc12/synset_words.txt\n",
"with open('./synset_words.txt', 'r') as f:\n",
" split_lines = [l.strip().split(' ') for l in f.readlines()]\n",
" id_to_synset = dict(enumerate([l[0] for l in split_lines]))\n",
" id_to_text = dict(enumerate([' '.join(l[1:]) for l in split_lines]))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jNMG0pFcJMbZ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"outputId": "6fabbf01-5d78-492f-9bd4-5209530652f1"
},
"source": [
"BATCH_SIZE = 128\n",
"mk, mr = runner(dict(model='ig_resnext101_32x32d'), dataset, device, collect_loss=True)"
],
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth\" to /root/.cache/torch/checkpoints/ig_resnext101_32x32-e4b90b00.pth\n",
"100%|██████████| 1876573776/1876573776 [01:41<00:00, 18563785.10it/s]\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 224, 224)\n",
"\tinterpolation: bilinear\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 4.914 (4.914, 26.048/s) \tPrec@1 86.719 (86.719)\tPrec@5 99.219 (99.219)\n",
"Test: [20/79]\tTime 3.224 (3.202, 39.978/s) \tPrec@1 77.344 (81.436)\tPrec@5 94.531 (95.238)\n",
"Test: [40/79]\tTime 3.365 (3.246, 39.431/s) \tPrec@1 68.750 (80.526)\tPrec@5 89.844 (94.684)\n",
"Test: [60/79]\tTime 3.475 (3.309, 38.680/s) \tPrec@1 79.688 (78.279)\tPrec@5 91.406 (93.686)\n",
" * Prec@1 77.020 (22.980) Prec@5 93.340 (6.660)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bUcaOI0rJBoG",
"colab_type": "text"
},
"source": [
"# The Best Predictions\n",
"Harmonicas and Carbonara"
]
},
{
"cell_type": "code",
"metadata": {
"id": "DE4UUv6aJD3l",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "c2da7a9c-4937-42ba-a206-57a64483f932"
},
"source": [
"nrows = 2\n",
"num_images = 10\n",
"best_idx = np.argsort(mr['losses_val'])[:num_images]\n",
"show_summary(best_idx, dataset, nrows)"
],
"execution_count": 17,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0sAAAewCAYAAAAbXsu/AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvUmzLEmW5/U7OpiZT3d884vIjCEz\nq7rppopBEGHdG74BHwIWbFnSm2bFhg+AsGTLCkSAHYJ0QzMISA0ZkZExvXm4g/t1NzOdWKiaud+b\nkVWR2RWZUVL2f8/fvc/d3ExNTfXo+Z9JJaXEhAkTJkyYMGHChAkTJky4DfXHbsCECRMmTJgwYcKE\nCRMm/BgxkaUJEyZMmDBhwoQJEyZM+A5MZGnChAkTJkyYMGHChAkTvgMTWZowYcKECRMmTJgwYcKE\n78BEliZMmDBhwoQJEyZMmDDhOzCRpQkTJkyYMGHChAkTJkz4DkxkacKECRMmTJgwYcKECRO+AxNZ\nmjBhwoQJEyZMmDBhwoTvwESWJkyYMGHChAkTJkyYMOE7MJGlCRMmTJgwYcKECRMmTPgO/KjJkoj8\nJyLypYi0IvIvReQ/+GO3acKECRMmTJgwYcKECf8w8KMlSyLyHwP/FfBfAP8u8P8A/6OIPPijNmzC\nhAkTJkyYMGHChAn/ICAppT92G74TIvIvgf89pfSflv8r4Bvgv04p/Zff4/sCPAHWP2hDJ0yYMGHC\nhAkTJkyY8PcBK+B5+h0IkPkBG/N7Q0Qq4N8D/sXwXkopisj/BPyHv+U7NVAfvPUY+Ksfsp0TJkyY\nMGHChAkTJkz4e4UPgGff9+AfaxjePUADr+68/wp49Fu+858DVweviShNmDBhwoQJEyZMmDDhEL9T\n1NmPlSz9PvgXwPHB64M/bnMmTJgwYcKECRMmTJjw9xk/yjA84C0QgId33n8IvPyuL6SUOqAb/p9T\nliZMmDBhwoQJEyZMmDDh98OP0rOUUuqBfw38s+G9UuDhnwH/2x+rXRMmTJgwYcKECRMmTPiHgx+r\nZwly2fD/VkT+D+BfAf8ZsAD+mz9qqyZMmDBhwoQJEyZMmPAPAj9aspRS+u9E5D7wz8lFHf5v4D9K\nKd0t+jBhwoQJEyZMmDBhwoQJf+f40e6z9G8KETkiV8WbMGHChAkTJkyYMGHCBIDjlNL19z34R5mz\nNGHChAkTJkyYMGHChAl/bPxow/D+GJCDf28j3Tri8NfhnaH6nhx8I905+LBAX3bopfHnrcvcucZv\ntOb39AaKyJ02pPEy6vAzudsT6Tfb9p0XuNvQ3+M734Xf9Xb/zgoh/qErKuYbTUlIZWzEw2edhibt\n2yUiIMLd0ZeHVbp93MGJYiwj9LcMvf339ue91Rty9zhBycHoTyBDO37bvabfeGd/6t+oZinlr9zq\ng5QSKSViujOXynfymJf9/Lxz3qGfxmuXY4d7GU55605+p/lX2jXOqf1586MTbj+ZW//8Rpv/kFU+\nB/kQY9zLHBFyvx4eB/k+053+E+6Oxe8Ucb9xT+k3f/uOLpe7v6TcVu/997q/fGlBifqONvybQ+Rv\nHyr75/ndsnmU0aJ+gCYerlZ33j9Y31JKo4wZxkJKCaVUab98xzm4s+D95iUTKXfQofy61aSDL8nB\nVdLt9w+PTykRY8zjsLTvO7utPJsUIzHF8ZTCgUw9kBu/DX+TtnD3mFvvH8yH754Df8v32cuS4dP9\nXLkjz+58fxSfB7J/eKYplWfye659w5VlfGZlXA/tuHOf3/cqt6X6b3/vt33nLn734397S/+mVU5u\n/8OtToeDtetQW9yP93Ec/rb2HXzv9lMbJkz6jalHOe+dm/iOht+91v7A71yrf5dh8z2PvXXft1Sh\nsq6mlH+PkZTi97z474eJLBVUqqIxNUoJSo86
wS2FVRA0oJWglCAKtFJU2jCrKmZ1TWUrQkrses/O\nB7xolG2wtsZog6hEiB7nO7q+pe06XNcTvCfFiIqpXDsrHkZrtNGICCFGeu9xwZf2FIGfEnFQxlJC\nih4cY8QHT4wBW2mOj5YsFg1KKdquZ3tzg8VzPtc8WFlW8wpTVYjWmUCRzyMpQkzDxW4J+HzZtF9c\nh58Hi8F+4T88RsafCUi3FOHhOwnZn4ThhIdzk6En7mpOd5TTW7gruA/OnRfwvYAfhNU4a4e23BJS\ngyJefpf9dZCs6IzHSrnXpJAEEhNCRIiEFHAx0QbFuo28vdxxsenxKGKC4AOIQhlLUhplLc1szmw2\no64blDYkEUIItG1LcD2kiDWa2hqMVgiRvm+5urpkd7Ml9I4YIiFGQoj52SqFUhqlFdYYqrrCGoOQ\n0ES0ysqlqapxLByvFpyu5sx1QtwW5VskelKKxKLApJSFWhwEW0oQIzEGYgiZwJX701WDtjaPxUHx\n0RZTz7DNAm1rkg9s12veXVzy+nrDdRtwIUIMKBKmamiOTlgcn7JcrmiG+9AGUYoQoe8drXM4Ekkr\n5rMZ905OOTs6xiihd46u6+h7h/eOFAKSYn4djj8AYh4NZXz4EGhdT+8DSRtM1aBtDSgSicoaVosF\nq/kMq1Q+vu/pvSemCAmsVtTGUFuDNQat9PdS4n53DOR1IHO5/ZvNhrfv3/H+6hofoZrNqWdzjLW5\nDTHS947N9oa22zGrLQ/Pzrh/ekZT1fgY2HYdrXP4sF/MBoKllMIajTEapSDGQIgBTyCmQIxxbNsg\nE7K4ERSCVjrLYiCGwJe/+pL/61/9n7Rt+7fesVaGB2ePOD+9h60atLEorUlAKO0ggVIKJWqUtRS5\nrFQhg2XBHkSEiKC0RkSRQiZvKUVEZfNBfn4KpTVKGYypaJoGay1KCSEGurZlt2vpe4fWhuPjY+aL\nBSTwIZQ5IUW5HciUFLIQRuKqtSKlSN/3ed4ipJgwxmBtRUoQgs9rnCrzTAlKaYxRQMI7l9dFgauL\nC168eIn3nocPH3J+fg8RTSjz2odAArQx4/2QhFg+V1ojSuFcz253Q/AOoy1VVWGrGmMMKSX63hGC\nRyuFNmZcc2OM9M4TE6OMystlxPUtlxfvePf2DaaqefjkA05OzzFl/VSjPM7yfbvb8vr1a169fMV6\nvUaApmmYzecslytWR0csFguMsTjnSUkw1gJ5bdVKMNbkeXCg0Xnvcc6BQGVtJuNASJEY8rgKIRCC\nJ4RAZfO9a63HNTCVdTYTUjXM0D0ZEZU/Q0CyXA0h4FxPt2txzqGVMJ/NqeqKlBIhlvlmDVpLlvsh\n0HY9Nzdbrq+uubq6pu16II+BfO0iE8o4z/+JB2pzHvlx0BFIKK2orGU2m7GYz5nNGpp5gzEmyzYR\nbGUxxqCUOiAP5Wx32MGwhiilMEqR4jCv8j0pnefofg7sx+NguDmcEyPRH4wAKRJ8IKY0zt1RLkLp\nCzm4X8a5N1wrk+44SlABjNG5v5XKzzUmYtEpQwhcX1/z/t171usNIQS00lRVTV1XVLaibhrmizlN\nU+cxFMI4/kIIRTdNhJDXUVXGhSprdAx5rKUUUcV4aLRGazPe/9CWvX6W+0iG8VXWghjCaFwYrm2t\noarqA/ksaGPRSufjYijn02OfhdImEcb3y5WzVC/zUylBKUUIAe/zK8Y87rwP3Gw23Gxu6Nodb1/+\niou33/6ghGkiSwWVtSxnC6xVGKMwem8d9TGRYhn8QK01xgiihUobllXNyWLB6WrFYjbHI1y1PZd9\npNM1qlnRNEuaqkbrSAgtbbtmvbnmenPNZnPDbrcldD0SI0oSSkBrRV3VzGYNWhv64Llpd2y7lhAT\nughpHyI9QEqoFFFFp/A+0HYtvW+ZzysePz7n4YNTrNFcXG148+YNs7TjT+9Z/umTOR88WDFfrRBb\noyShU0KS
IMEjIZBiyKSJ/eIskrKwJg7cIn9eBEnijuWkKBuqLNyIZAJBXvxhEE9xT0oKqYBi+WM4\nVSINrHYkNoeRp
"text/plain": [
"<Figure size 1000x5000 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"n01440764/7.jpeg ground truth = burrito\n",
"Predicted:\n",
" 100.000 burrito\n",
"\n",
"n01443537/1.jpeg ground truth = carbonara\n",
"Predicted:\n",
" 100.000 carbonara\n",
"\n",
"n01484850/2.jpeg ground truth = carbonara\n",
"Predicted:\n",
" 100.000 carbonara\n",
"\n",
"n01491361/8.jpeg ground truth = washer, automatic washer, washing machine\n",
"Predicted:\n",
" 100.000 washer, automatic washer, washing machine\n",
"\n",
"n01494475/6.jpeg ground truth = rugby ball\n",
"Predicted:\n",
" 100.000 rugby ball\n",
"\n",
"n01496331/3.jpeg ground truth = harmonica, mouth organ, harp, mouth harp\n",
"Predicted:\n",
" 100.000 harmonica, mouth organ, harp, mouth harp\n",
"\n",
"n01498041/8.jpeg ground truth = frilled lizard, Chlamydosaurus kingi\n",
"Predicted:\n",
" 100.000 frilled lizard, Chlamydosaurus kingi\n",
"\n",
"n01514668/4.jpeg ground truth = lens cap, lens cover\n",
"Predicted:\n",
" 100.000 lens cap, lens cover\n",
"\n",
"n01514859/0.jpeg ground truth = bobsled, bobsleigh, bob\n",
"Predicted:\n",
" 100.000 bobsled, bobsleigh, bob\n",
"\n",
"n01518878/5.jpeg ground truth = harmonica, mouth organ, harp, mouth harp\n",
"Predicted:\n",
" 100.000 harmonica, mouth organ, harp, mouth harp\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VP7bb25UJFIh",
"colab_type": "text"
},
"source": [
"# The Worst Predictions\n",
"As usual, the worst predictions are hard, in most cases due to labelling issues or really challenging images. But hey, some of them are amusing. Who wouldn't want a pirate guinea pig? And I'm pretty sure that's a marmot, not a beaver..."
]
},
{
"cell_type": "code",
"metadata": {
"id": "EdMgUWVeJGo0",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "80d4e7ea-c98e-4b26-e55c-f16e4d671ece"
},
"source": [
"nrows = 2\n",
"num_images = 20\n",
"worst_idx = np.argsort(mr['losses_val'])[-num_images:][::-1]\n",
"show_summary(worst_idx, dataset, nrows)"
],
"execution_count": 18,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0sAAA8lCAYAAABlGgIoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvUmzLVmW3/Xbrbuf5javi6iszBJm\nWRKFSYDRCDPGGsA34EPAgCkjDCZixIQPgDGRwUfAjLlMtMKQCUxSFRSVGe2Ld5vTeLM7Bmu7n3Nf\nvKjKlJQWgcr/aTfjvnNP48d979X8138tV6UUVqxYsWLFihUrVqxYsWLFS+gf+wBWrFixYsWKFStW\nrFix4qeINVlasWLFihUrVqxYsWLFik9gTZZWrFixYsWKFStWrFix4hNYk6UVK1asWLFixYoVK1as\n+ATWZGnFihUrVqxYsWLFihUrPoE1WVqxYsWKFStWrFixYsWKT2BNllasWLFixYoVK1asWLHiE1iT\npRUrVqxYsWLFihUrVqz4BNZkacWKFStWrFixYsWKFSs+gTVZWrFixYoVK1asWLFixYpP4CedLCml\n/kOl1P+jlBqUUn9PKfXv/NjHtGLFihUrVqxYsWLFir8c+MkmS0qp/wD4L4H/DPg3gf8d+O+VUu9+\n1ANbsWLFihUrVqxYsWLFXwqoUsqPfQyfhFLq7wH/UynlP6r/1sCfAf9VKeW/+FEPbsWKFStWrFix\nYsWKFf/Cw/7YB/ApKKU88G8Bf3t+rJSSlVL/A/Dv/sBrGqD56OFXwIff1XGuWLFixYoVK1asWLHi\n/zfYA1+U36Ja9JNMloA3gAG+/ujxr4E/+oHX/CfAf/q7PKgVK1asWLFixYoVK1b8/xo/B379mz75\nJ9uz9E+Bvw3cXv38/Mc9nBUrVqxYsWLFihUrVvzEcPhtnvxTrSy9BxLw2UePfwZ89akXlFJGYJz/\nrZT6nR3cihUrVqxYsWLFihUr/sXHT7KyVEqZgP8F+FvzY3XAw98C/u6PdVwrVqxYsWLFihUrVqz4\ny4OfamUJZGz4f6OU+p+B/xH4j4Et8F//qEe1YsWKFStWrFixYsWKvxT4ySZLpZT/Tin1FvjPgc+B\nvw/8+6WUj4c+rFixYsWKFStWrFixYsU/d/xk77P0zwql1A3w9GMfx4oVK1asWLFixYoVK34yuC2l\nPP+mT/5J9iytWLFixYoVK1asWLFixY+NNVlasWLFihUrVqxYsWLFik/gJ9uz9GNCoeBq8rhIFRXf\nm0ZeYBYx/tQmlS/foRT+PKFl/Wp8+vA/8Z2XF8Kf/87/NLj+sDIfAah6nPNjHx3URUn6Fx3P99//\ne8+4eu/5fZWaH5//Vpbj+e1UrOXFb3/xklGfvjZKobVGvzjWTM6Fi6z26m/1/9XycrV8z1JK/Xl5\nfC8/W714p/l1f9GxK1VfUa6vm0Lr+b3V8vmXr6aWj/rU+tJaf/K4v7cEvnepf+B4lZJt8sPfYn7a\nJ96uvPj9+x98jU99wsvzOr/N5bPm/Xc5X/Onlqtzetkj9fHrvy175+rcKnW9kn/wy8+fUUr56DXw\ncuHXv6mXr1v+ev3H733/v3gDvTzG33bP/fnQWi9r6uP3fmlm1Ivv8c8iXy+lfOKc/HNwJtfX6JPH\n9/Jaf88G/eB1+vT7fPpw1ffs8w+97mOo7/0iy0wtH/LPcO0/8oOqboyP3+7j7/LJYy5XtvT6YP+C\nU/eDFuIHbN1LX/T9v18/9vL4Xn7IZa29/MP142I3uNiZajtexELl6j/13H3Cgi2PvjiXy3e8trff\n96PzXr/2T7r6O3W1vl8a5O/7rd8I1+f92oZf2dRrC3V15l6c71LfZbG/P2Bf589SdY/M9vnay/55\ny/v7ccD8HX74uR8f+2+7fb4fF3z/fV5clx8B
BUgxknP+nX7OmixVeO/xzlMAozXGGAByLuScAIW1\nBmMMuRRySuSca7KhUfriJFThajNcjFHOmXy1+Y3RXCckWn2/0JdrQHjt1C/B0sWgxJRJKaKVwjmL\ndxatNTknpimQUgSlMdailCalRIiRlDKXNKBgtcZag9IWYyzWWUz93JwTJc8BVCalxBRCPVJFKfV4\nAaU0pWQoCdR8Ti2qvtdsGK4DRK2U/L1QzzlYY9DakHIipoRCybXyjlIKIQRSSlAKuRRSTKT6WnWJ\n1tHGoI2loMWo1Y2lqMGlKhhtsM6itSGXTIyZgqJpGzbbLW3boZQihIlhGBj6gWkayanMvlwCofmz\n6xdVSlGQdRRjJKVIoaCVwViD0Wa5rvXk1ffQOOdkLaqLudtud/z+7/8+96/uab2jxDNPD9/xq19/\nxcPTAa0bUI5UDDkj50MljEl4r7jZb9jtNmg05+PI08OBw/HMmAIxZ0DhfIO1Dmctxhg0BVUylEyK\nEzFOTNNELgWlLUVpcl2SViuskb1SciZG+b7eOfa7HXf3d+y2W5RSDEPP6XQixoB3jm7TAYppTIxT\nJMYgDtNovLO8enVP2zaczz2PT88cD0dO555Ur2fOmUJBGUkm87zv6k/Jsp+UUhhjsNZhrZV9XOa9\nNS9MhXeepm2w3qG0WvZxDIFpknMQQyDnhAK0MZWjKEuwUXIml/wiIFBKYayBUmRNxLisY2Ms2hic\n97TtBtc2+KbFNQ0hRob+wDQcQSm6zS2b7R1aG46HA0+P3xHDGWcK++2GrmlIsZCzwliPtp52s8G3\nHlUSYTjR973s1WrvxG5pYgycTmdCCBhjcM6ilGIaJ/qhp+SCcx5rGoy2aA2ZQM6BlCdyjlCgazd0\n7QaKJsZCKQptxL6EGEk51X1f6nnKlaAQOxOCnOOYAjFGYt3zOeUrO1gTw1yIKX3ayH8E7z2//OUv\nefP6dV0fiVRtDJRqn2UdNE1D2zYopcm5EFO8xGiqEgDMAVOWPV/tTEyJFCPWWAqZaZzQWuOdAyAl\nsVk5ZxQK6xzOOXk8Jpy15JIvAZzSFJBzhVwrbczyupwTcRgoKYKCGOLiM9LVucm5oKvfKqVgrcUa\nu+whax1KKXIW3zLbppzTCztrnUNV39Vttmy3O1KW9ysgvjIFcgzkIvtwttuq+kJrLNro5d/zZZ39\nk7FWjj9G8Rv5YnPnvT9/D2PM8nuMkSkEckqUAjFFrDFYa8nznlNyLpVSNM6hlRa7pjSqFMIwEmIg\na8ncyjjRxMLOOIzSsscVFHUhNZQCjWaJuI0ixEjMCe88nffoAiUlhmHgPPbkOTHLGWstbdPK9UyZ\ncRoZhkFsUuNJOXPue2JKxBSJMVb/o0glg1Yoq8GI/9JG/LG1BqM1GkXbNHjnGaeRw+FIiAHnPU3T\n0jYNm66j22yWdZpzkWMZR8ZplPVT90iMsjdzkWO31mGMxEVQSCkyTRPOSFwSU+LcD4zjRIiZYRgZ\nQ6zvUdhuNsSUOR+PfH53zx/+4hfsN1ustSgkLmh8A6XQ9z0pyT4x2ixrCBQ5JbQ2S7agtCKXTMqZ\nnDLTNFafnIhR1nPTNGw2G/G91lCy2GyjjcRUSpFQJAWRwjkGDlPPU3/iNI4oa3GtJ8W0+IcUIzmK\nbWkaz3634/Zmx2bT4b3HXq3ZeX1L8F+W2O/j5HlJKLM4rDLHDmre0x/FI0riqpznxyT+KTUGWog2\nrVHaEbLmcJ4YxhFVEkbJngsJpgwZhbOWbdvQeruQoXM8eZ0cXrKpciEpfpCZeklQzN9hec7165Ui\nhMA//Af/gH/8j/7R7zRhWpMlABSNb9nv9gBLYACqbqIIyCay1r4w9kUptJbgZk6MVBHnOW8uKKTq\nYHLOaK2wcxBagzpjDFYb2SR1ceRSiDlXR1Y3bqkBoUIMg9bklBmmiXEcMQq2Xct2u8F7SwgTx+NR\nAiKlsL5B
G0fKpQZ7QRZiTpAS3mr5nq7B+pamaXDWorUip0gMUTZckuRlirFuREWMmZByjTMlSCx5\nQmvEkFlXkygxA
"text/plain": [
"<Figure size 1000x10000 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"n01440764/2.jpeg ground truth = bulbul\n",
"Predicted:\n",
" 100.000 mousetrap\n",
"\n",
"n01443537/6.jpeg ground truth = sarong\n",
"Predicted:\n",
" 100.000 crutch\n",
"\n",
"n01484850/4.jpeg ground truth = guinea pig, Cavia cobaya\n",
"Predicted:\n",
" 100.000 pirate, pirate ship\n",
"\n",
"n01491361/9.jpeg ground truth = beaver\n",
"Predicted:\n",
" 100.000 robin, American robin, Turdus migratorius\n",
"\n",
"n01494475/7.jpeg ground truth = doormat, welcome mat\n",
"Predicted:\n",
" 100.000 hay\n",
"\n",
"n01496331/6.jpeg ground truth = marmoset\n",
"Predicted:\n",
" 100.000 jackfruit, jak, jack\n",
"\n",
"n01498041/6.jpeg ground truth = goblet\n",
"Predicted:\n",
" 100.000 hip, rose hip, rosehip\n",
"\n",
"n01514668/5.jpeg ground truth = handkerchief, hankie, hanky, hankey\n",
"Predicted:\n",
" 100.000 rocking chair, rocker\n",
"\n",
"n01514859/8.jpeg ground truth = dial telephone, dial phone\n",
"Predicted:\n",
" 100.000 sewing machine\n",
"\n",
"n01518878/7.jpeg ground truth = hot pot, hotpot\n",
"Predicted:\n",
" 99.658 corn\n",
" 0.332 ear, spike, capitulum\n",
" 0.011 cucumber, cuke\n",
"\n",
"n01530575/3.jpeg ground truth = binder, ring-binder\n",
"Predicted:\n",
" 99.463 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin\n",
" 0.522 garbage truck, dustcart\n",
"\n",
"n01531178/8.jpeg ground truth = space bar\n",
"Predicted:\n",
" 100.000 maze, labyrinth\n",
"\n",
"n01532829/0.jpeg ground truth = syringe\n",
"Predicted:\n",
" 97.461 stethoscope\n",
" 2.556 lab coat, laboratory coat\n",
"\n",
"n01534433/7.jpeg ground truth = cornet, horn, trumpet, trump\n",
"Predicted:\n",
" 100.000 accordion, piano accordion, squeeze box\n",
"\n",
"n01537544/5.jpeg ground truth = sarong\n",
"Predicted:\n",
" 100.000 umbrella\n",
"\n",
"n01558993/9.jpeg ground truth = sandal\n",
"Predicted:\n",
" 100.000 park bench\n",
" 0.011 sunglass\n",
"\n",
"n01560419/2.jpeg ground truth = sweatshirt\n",
"Predicted:\n",
" 99.902 acoustic guitar\n",
" 0.088 pick, plectrum, plectron\n",
"\n",
"n01580077/5.jpeg ground truth = modem\n",
"Predicted:\n",
" 99.756 carton\n",
" 0.106 packet\n",
" 0.069 envelope\n",
" 0.015 binder, ring-binder\n",
" 0.014 tray\n",
"\n",
"n01582220/6.jpeg ground truth = jersey, T-shirt, tee shirt\n",
"Predicted:\n",
" 99.902 park bench\n",
" 0.010 neck brace\n",
" 0.009 gasmask, respirator, gas helmet\n",
" 0.005 soccer ball\n",
" 0.002 cowboy hat, ten-gallon hat\n",
"\n",
"n01592084/4.jpeg ground truth = Rottweiler\n",
"Predicted:\n",
" 99.951 malinois\n",
" 0.054 German shepherd, German shepherd dog, German police dog, alsatian\n",
" 0.002 Leonberg\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lQlFqynvbFwy",
"colab_type": "text"
},
"source": [
"# The Worst Predictions with Test Time Pooling\n",
"\n",
"Looking at the worst predictions above, there are a number of examples where the label refers to a smaller, less obvious object at the periphery of the scene (i.e. the syringe at the very edge of the stethoscope image, or the trumpet, which is much less prominent than the accordion). Seeing this, I decided to run it again at a higher resolution (288x288), with Test Time Pooling enabled and a 100% crop. This gives a little over a 1% boost in top-1 (77.02% to 78.10%) and just under 1% in top-5 (93.34% to 94.10%), and yes, the examples mentioned above are no longer among the worst."
]
},
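{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"For reference, here is a rough NumPy sketch of what Test Time Pooling does, based on my reading of timm's `TestTimePoolHead`. Treat it as an illustration rather than the library's code. The helper names `sliding_avg_pool` and `ttp_logits` are made up for this sketch, and it assumes a stride-32 backbone with a 2048-channel feature map and a 1000-class linear head, fed with random stand-in data.\n",
"\n",
"Instead of squashing the larger feature map with a single global average pool, the sketch works in three steps:\n",
"\n",
"* It average pools the features with the same 7x7 window the network saw at 224x224, using stride 1.\n",
"* It applies the classifier at every resulting position, like a 1x1 conv.\n",
"* It reduces the spatial logits with the mean of average and max pooling.\n",
"\n",
"At 288x288 the feature map is 9x9, so this yields a 3x3 grid of logits. With a 7x7 feature map (a 224x224 input), it reduces to the usual global average pool followed by the linear classifier."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"def sliding_avg_pool(feat, k):\n",
"    # average every k x k window of a (C, H, W) feature map with stride 1\n",
"    c, h, w = feat.shape\n",
"    out = np.empty((c, h - k + 1, w - k + 1), dtype=feat.dtype)\n",
"    for i in range(h - k + 1):\n",
"        for j in range(w - k + 1):\n",
"            out[:, i, j] = feat[:, i:i + k, j:j + k].mean(axis=(1, 2))\n",
"    return out\n",
"\n",
"def ttp_logits(feat, fc_weight, fc_bias, original_pool=7):\n",
"    # 1) pool with the window size the classifier was trained on (7x7 for a 224px ResNeXt)\n",
"    pooled = sliding_avg_pool(feat, original_pool)\n",
"    # 2) apply the linear classifier at every spatial position, like a 1x1 conv\n",
"    logits_map = np.einsum('kc,chw->khw', fc_weight, pooled) + fc_bias[:, None, None]\n",
"    # 3) reduce the spatial logits with the mean of global average and global max pooling\n",
"    return 0.5 * (logits_map.mean(axis=(1, 2)) + logits_map.max(axis=(1, 2)))\n",
"\n",
"# usage with random stand-in data: a 288px input at stride 32 gives a 9x9 feature map\n",
"rng = np.random.default_rng(0)\n",
"feat = rng.standard_normal((2048, 9, 9)).astype(np.float32)\n",
"fc_w = 0.01 * rng.standard_normal((1000, 2048)).astype(np.float32)\n",
"fc_b = np.zeros(1000, dtype=np.float32)\n",
"print(sliding_avg_pool(feat, 7).shape)  # (2048, 3, 3)\n",
"print(ttp_logits(feat, fc_w, fc_b).shape)  # (1000,)"
],
"execution_count": 0,
"outputs": []
},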
{
"cell_type": "code",
"metadata": {
"id": "U0nw7gjtNjU-",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "c041e96d-8045-4d3b-c240-edd5c3d0976e"
},
"source": [
"# only doing this one if we're on a T4\n",
"if HAS_T4:\n",
" mk, mr = runner(dict(model='ig_resnext101_32x32d', img_size=288, ttp=True), dataset, device, collect_loss=True)\n",
" nrows = 2\n",
" num_images = 20\n",
" worst_idx = np.argsort(mr['losses_val'])[-num_images:][::-1]\n",
" show_summary(worst_idx, dataset, nrows)"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth\" to /root/.cache/torch/checkpoints/ig_resnext101_32x32-e4b90b00.pth\n",
"100%|██████████| 1876573776/1876573776 [01:35<00:00, 19579360.70it/s]\n",
"Applying test time pooling to model\n",
"Data processing configuration for current model + dataset:\n",
"\tinput_size: (3, 288, 288)\n",
"\tinterpolation: bicubic\n",
"\tmean: (0.485, 0.456, 0.406)\n",
"\tstd: (0.229, 0.224, 0.225)\n",
"\tcrop_pct: 0.875\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Test: [0/79]\tTime 36.513 (36.513, 3.506/s) \tPrec@1 87.500 (87.500)\tPrec@5 100.000 (100.000)\n",
"Test: [20/79]\tTime 5.699 (6.974, 18.353/s) \tPrec@1 79.688 (82.329)\tPrec@5 96.094 (95.945)\n",
"Test: [40/79]\tTime 5.764 (6.389, 20.033/s) \tPrec@1 70.312 (81.326)\tPrec@5 90.625 (95.312)\n",
"Test: [60/79]\tTime 5.792 (6.186, 20.694/s) \tPrec@1 81.250 (79.188)\tPrec@5 95.312 (94.365)\n",
" * Prec@1 78.100 (21.900) Prec@5 94.100 (5.900)\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0sAAA8lCAYAAABlGgIoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvUmzLVmW3/Xbrbuf5javi6iszBJm\nWRKFSYDRCDPGGsA34EPAgCkjDCZixIQPgDGRwUfAjLlMtMKQCUxSFRSVGe2Ld5vTeLM7Bmu7n3Nf\nvKjKlJQWgcr/aTfjvnNP48d979X8138tV6UUVqxYsWLFihUrVqxYsWLFS+gf+wBWrFixYsWKFStW\nrFix4qeINVlasWLFihUrVqxYsWLFik9gTZZWrFixYsWKFStWrFix4hNYk6UVK1asWLFixYoVK1as\n+ATWZGnFihUrVqxYsWLFihUrPoE1WVqxYsWKFStWrFixYsWKT2BNllasWLFixYoVK1asWLHiE1iT\npRUrVqxYsWLFihUrVqz4BNZkacWKFStWrFixYsWKFSs+gTVZWrFixYoVK1asWLFixYpP4CedLCml\n/kOl1P+jlBqUUn9PKfXv/NjHtGLFihUrVqxYsWLFir8c+MkmS0qp/wD4L4H/DPg3gf8d+O+VUu9+\n1ANbsWLFihUrVqxYsWLFXwqoUsqPfQyfhFLq7wH/UynlP6r/1sCfAf9VKeW/+FEPbsWKFStWrFix\nYsWKFf/Cw/7YB/ApKKU88G8Bf3t+rJSSlVL/A/Dv/sBrGqD56OFXwIff1XGuWLFixYoVK1asWLHi\n/zfYA1+U36Ja9JNMloA3gAG+/ujxr4E/+oHX/CfAf/q7PKgVK1asWLFixYoVK1b8/xo/B379mz75\nJ9uz9E+Bvw3cXv38/Mc9nBUrVqxYsWLFihUrVvzEcPhtnvxTrSy9BxLw2UePfwZ89akXlFJGYJz/\nrZT6nR3cihUrVqxYsWLFihUr/sXHT7KyVEqZgP8F+FvzY3XAw98C/u6PdVwrVqxYsWLFihUrVqz4\ny4OfamUJZGz4f6OU+p+B/xH4j4Et8F//qEe1YsWKFStWrFixYsWKvxT4ySZLpZT/Tin1FvjPgc+B\nvw/8+6WUj4c+rFixYsWKFStWrFixYsU/d/xk77P0zwql1A3w9GMfx4oVK1asWLFixYoVK34yuC2l\nPP+mT/5J9iytWLFixYoVK1asWLFixY+NNVlasWLFihUrVqxYsWLFik/gJ9uz9GNCoeBq8rhIFRXf\nm0ZeYBYx/tQmlS/foRT+PKFl/Wp8+vA/8Z2XF8Kf/87/NLj+sDIfAah6nPNjHx3URUn6Fx3P99//\ne8+4eu/5fZWaH5//Vpbj+e1UrOXFb3/xklGfvjZKobVGvzjWTM6Fi6z26m/1/9XycrV8z1JK/Xl5\nfC8/W714p/l1f9GxK1VfUa6vm0Lr+b3V8vmXr6aWj/rU+tJaf/K4v7cEvnepf+B4lZJt8sPfYn7a\nJ96uvPj9+x98jU99wsvzOr/N5bPm/Xc5X/Onlqtzetkj9fHrvy175+rcKnW9kn/wy8+fUUr56DXw\ncuHXv6mXr1v+ev3H733/v3gDvTzG33bP/fnQWi9r6uP3fmlm1Ivv8c8iXy+lfOKc/HNwJtfX6JPH\n9/Jaf88G/eB1+vT7fPpw1ffs8w+97mOo7/0iy0wtH/LPcO0/8oOqboyP3+7j7/LJYy5XtvT6YP+C\nU/eDFuIHbN1LX/T9v18/9vL4Xn7IZa29/MP142I3uNiZajtexELl6j/13H3Cgi2PvjiXy3e8trff\n96PzXr/2T7r6O3W1vl8a5O/7rd8I1+f92oZf2dRrC3V15l6c71LfZbG/P2Bf589SdY/M9vnay/55\ny/v7ccD8HX74uR8f+2+7fb4fF3z/fV5clx8B
BUgxknP+nX7OmixVeO/xzlMAozXGGAByLuScAIW1\nBmMMuRRySuSca7KhUfriJFThajNcjFHOmXy1+Y3RXCckWn2/0JdrQHjt1C/B0sWgxJRJKaKVwjmL\ndxatNTknpimQUgSlMdailCalRIiRlDKXNKBgtcZag9IWYyzWWUz93JwTJc8BVCalxBRCPVJFKfV4\nAaU0pWQoCdR8Ti2qvtdsGK4DRK2U/L1QzzlYY9DakHIipoRCybXyjlIKIQRSSlAKuRRSTKT6WnWJ\n1tHGoI2loMWo1Y2lqMGlKhhtsM6itSGXTIyZgqJpGzbbLW3boZQihIlhGBj6gWkayanMvlwCofmz\n6xdVSlGQdRRjJKVIoaCVwViD0Wa5rvXk1ffQOOdkLaqLudtud/z+7/8+96/uab2jxDNPD9/xq19/\nxcPTAa0bUI5UDDkj50MljEl4r7jZb9jtNmg05+PI08OBw/HMmAIxZ0DhfIO1Dmctxhg0BVUylEyK\nEzFOTNNELgWlLUVpcl2SViuskb1SciZG+b7eOfa7HXf3d+y2W5RSDEPP6XQixoB3jm7TAYppTIxT\nJMYgDtNovLO8enVP2zaczz2PT88cD0dO555Ur2fOmUJBGUkm87zv6k/Jsp+UUhhjsNZhrZV9XOa9\nNS9MhXeepm2w3qG0WvZxDIFpknMQQyDnhAK0MZWjKEuwUXIml/wiIFBKYayBUmRNxLisY2Ms2hic\n97TtBtc2+KbFNQ0hRob+wDQcQSm6zS2b7R1aG46HA0+P3xHDGWcK++2GrmlIsZCzwliPtp52s8G3\nHlUSYTjR973s1WrvxG5pYgycTmdCCBhjcM6ilGIaJ/qhp+SCcx5rGoy2aA2ZQM6BlCdyjlCgazd0\n7QaKJsZCKQptxL6EGEk51X1f6nnKlaAQOxOCnOOYAjFGYt3zOeUrO1gTw1yIKX3ayH8E7z2//OUv\nefP6dV0fiVRtDJRqn2UdNE1D2zYopcm5EFO8xGiqEgDMAVOWPV/tTEyJFCPWWAqZaZzQWuOdAyAl\nsVk5ZxQK6xzOOXk8Jpy15JIvAZzSFJBzhVwrbczyupwTcRgoKYKCGOLiM9LVucm5oKvfKqVgrcUa\nu+whax1KKXIW3zLbppzTCztrnUNV39Vttmy3O1KW9ysgvjIFcgzkIvtwttuq+kJrLNro5d/zZZ39\nk7FWjj9G8Rv5YnPnvT9/D2PM8nuMkSkEckqUAjFFrDFYa8nznlNyLpVSNM6hlRa7pjSqFMIwEmIg\na8ncyjjRxMLOOIzSsscVFHUhNZQCjWaJuI0ixEjMCe88nffoAiUlhmHgPPbkOTHLGWstbdPK9UyZ\ncRoZhkFsUuNJOXPue2JKxBSJMVb/o0glg1Yoq8GI/9JG/LG1BqM1GkXbNHjnGaeRw+FIiAHnPU3T\n0jYNm66j22yWdZpzkWMZR8ZplPVT90iMsjdzkWO31mGMxEVQSCkyTRPOSFwSU+LcD4zjRIiZYRgZ\nQ6zvUdhuNsSUOR+PfH53zx/+4hfsN1ustSgkLmh8A6XQ9z0pyT4x2ixrCBQ5JbQ2S7agtCKXTMqZ\nnDLTNFafnIhR1nPTNGw2G/G91lCy2GyjjcRUSpFQJAWRwjkGDlPPU3/iNI4oa3GtJ8W0+IcUIzmK\nbWkaz3634/Zmx2bT4b3HXq3ZeX1L8F+W2O/j5HlJKLM4rDLHDmre0x/FI0riqpznxyT+KTUGWog2\nrVHaEbLmcJ4YxhFVEkbJngsJpgwZhbOWbdvQeruQoXM8eZ0cXrKpciEpfpCZeklQzN9hec7165Ui\nhMA//Af/gH/8j/7R7zRhWpMlABSNb9nv9gBLYACqbqIIyCay1r4w9kUptJbgZk6MVBHnOW8uKKTq\nYHLOaK2wcxBagzpjDFYb2SR1ceRSiDlXR1Y3bqkBoUIMg9bklBmmiXEcMQq2Xct2u8F7SwgTx+NR\nAiKlsL5B
G0fKpQZ7QRZiTpAS3mr5nq7B+pamaXDWorUip0gMUTZckuRlirFuREWMmZByjTMlSCx5\nQmvEkFlXkygxA
"text/plain": [
"<Figure size 1000x10000 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"n01440764/2.jpeg ground truth = bulbul\n",
"Predicted:\n",
" 100.000 mousetrap\n",
"\n",
"n01443537/6.jpeg ground truth = sarong\n",
"Predicted:\n",
" 100.000 crutch\n",
"\n",
"n01484850/2.jpeg ground truth = pot, flowerpot\n",
"Predicted:\n",
" 100.000 Polaroid camera, Polaroid Land camera\n",
"\n",
"n01491361/7.jpeg ground truth = hot pot, hotpot\n",
"Predicted:\n",
" 99.902 corn\n",
" 0.100 ear, spike, capitulum\n",
"\n",
"n01494475/6.jpeg ground truth = jersey, T-shirt, tee shirt\n",
"Predicted:\n",
" 100.000 park bench\n",
"\n",
"n01496331/6.jpeg ground truth = goblet\n",
"Predicted:\n",
" 100.000 hip, rose hip, rosehip\n",
"\n",
"n01498041/6.jpeg ground truth = marmoset\n",
"Predicted:\n",
" 100.000 jackfruit, jak, jack\n",
"\n",
"n01514668/7.jpeg ground truth = custard apple\n",
"Predicted:\n",
" 100.000 ant, emmet, pismire\n",
"\n",
"n01514859/7.jpeg ground truth = corn\n",
"Predicted:\n",
" 100.000 hotdog, hot dog, red hot\n",
"\n",
"n01518878/5.jpeg ground truth = sarong\n",
"Predicted:\n",
" 100.000 umbrella\n",
"\n",
"n01530575/3.jpeg ground truth = wool, woolen, woollen\n",
"Predicted:\n",
" 100.000 doormat, welcome mat\n",
"\n",
"n01531178/9.jpeg ground truth = groom, bridegroom\n",
"Predicted:\n",
" 100.000 sombrero\n",
"\n",
"n01532829/8.jpeg ground truth = space bar\n",
"Predicted:\n",
" 99.951 maze, labyrinth\n",
" 0.015 joystick\n",
" 0.010 jigsaw puzzle\n",
"\n",
"n01534433/0.jpeg ground truth = theater curtain, theatre curtain\n",
"Predicted:\n",
" 99.951 altar\n",
" 0.045 throne\n",
" 0.003 monastery\n",
" 0.002 church, church building\n",
"\n",
"n01537544/5.jpeg ground truth = common iguana, iguana, Iguana iguana\n",
"Predicted:\n",
" 99.951 fountain\n",
" 0.007 triceratops\n",
" 0.006 pedestal, plinth, footstall\n",
" 0.003 palace\n",
"\n",
"n01558993/5.jpeg ground truth = modem\n",
"Predicted:\n",
" 99.951 carton\n",
" 0.050 packet\n",
" 0.008 tray\n",
" 0.003 crate\n",
"\n",
"n01560419/4.jpeg ground truth = handkerchief, hankie, hanky, hankey\n",
"Predicted:\n",
" 99.951 accordion, piano accordion, squeeze box\n",
" 0.017 stage\n",
" 0.012 unicycle, monocycle\n",
" 0.006 spatula\n",
" 0.003 plunger, plumber's helper\n",
"\n",
"n01580077/6.jpeg ground truth = sunglasses, dark glasses, shades\n",
"Predicted:\n",
" 100.000 volleyball\n",
"\n",
"n01582220/5.jpeg ground truth = swing\n",
"Predicted:\n",
" 100.000 carousel, carrousel, merry-go-round, roundabout, whirligig\n",
"\n",
"n01592084/9.jpeg ground truth = beaver\n",
"Predicted:\n",
" 100.000 robin, American robin, Turdus migratorius\n",
"\n"
],
"name": "stdout"
}
]
}
]
}