{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deploying PyTorch in Python via a REST API with Flask\n=====================================================\n\n**Author**: [Avinash Sajjanshetty](https://avi.im)\n\nIn this tutorial, we will deploy a PyTorch model using Flask and expose\na REST API for model inference. In particular, we will deploy a\npretrained DenseNet 121 model which classifies images.\n\n```{=html}\n
TIP: All the code used here is released under the MIT license and is available on GitHub.
\n```\nThis is the first in a series of tutorials on deploying PyTorch\nmodels in production. Using Flask in this way is by far the easiest way\nto start serving your PyTorch models, but it is not suited to use\ncases with high performance requirements. For that:\n\n> - If you're already familiar with TorchScript, you can jump\n> straight into our [Loading a TorchScript Model in\n> C++](https://pytorch.org/tutorials/advanced/cpp_export.html)\n> tutorial.\n> - If you first need a refresher on TorchScript, check out our [Intro\n> to\n> TorchScript](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html)\n> tutorial.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "API Definition\n==============\n\nWe will first define our API endpoint and its request and response types.\nOur API endpoint will be at `/predict`, which takes HTTP POST requests\nwith a `file` parameter that contains the image. The response will be\na JSON response containing the prediction:\n\n``` {.sh}\n{\"class_id\": \"n02124075\", \"class_name\": \"Egyptian_cat\"}\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dependencies\n============\n\nInstall the required dependencies by running the following command\n(the `weights` argument used later in this tutorial requires torchvision\n0.13 or newer):\n\n``` {.sh}\npip install Flask==2.0.1 torchvision==0.13.0\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Simple Web Server\n=================\n\nFollowing is a simple web server, taken from Flask's documentation:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from flask import Flask\napp = Flask(__name__)\n\n\n@app.route('/')\ndef hello():\n    return 'Hello World!'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Assuming the snippet above is saved as `app.py`, we will now move the\nroute to `/predict`, restrict it to POST requests, and change the\nresponse type so that it returns a JSON\nresponse containing the ImageNet class id and name. 
The updated `app.py`\nfile now looks like this:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from flask import Flask, jsonify\napp = Flask(__name__)\n\n@app.route('/predict', methods=['POST'])\ndef predict():\n    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Inference\n=========\n\nIn the next sections we will focus on writing the inference code. This\ninvolves two parts: first we prepare the image so that it can be\nfed to DenseNet, and then we write the code to get the actual\nprediction from the model.\n\nPreparing the image\n-------------------\n\nThe DenseNet model requires a 3-channel RGB image of size\n224 x 224. We will also normalize the image tensor with the required\nmean and standard deviation values. You can read more about it\n[here](https://pytorch.org/vision/stable/models.html).\n\nWe will use `transforms` from the `torchvision` library to build a\ntransform pipeline, which transforms our images as required. You can\nread more about transforms\n[here](https://pytorch.org/vision/stable/transforms.html).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import io\n\nimport torchvision.transforms as transforms\nfrom PIL import Image\n\ndef transform_image(image_bytes):\n    my_transforms = transforms.Compose([transforms.Resize(255),\n                                        transforms.CenterCrop(224),\n                                        transforms.ToTensor(),\n                                        transforms.Normalize(\n                                            [0.485, 0.456, 0.406],\n                                            [0.229, 0.224, 0.225])])\n    # Convert to RGB so that grayscale or RGBA images still yield 3 channels\n    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')\n    # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)\n    return my_transforms(image).unsqueeze(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above method takes image data in bytes, applies the series of\ntransforms and returns a tensor. 
To test the above method, read an image\nfile in bytes mode (first replacing\n`../_static/img/sample_file.jpeg` with the actual path to\nthe file on your computer) and see if you get a tensor back:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with open(\"../_static/img/sample_file.jpeg\", 'rb') as f:\n    image_bytes = f.read()\n    tensor = transform_image(image_bytes=image_bytes)\n    print(tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prediction\n==========\n\nNow we will use a pretrained DenseNet 121 model to predict the image class.\nWe will load the one from the `torchvision` library and run\ninference with it. While we'll be using a pretrained model in this example, you\ncan use this same approach for your own models. See the PyTorch tutorial on\nsaving and loading models for more about loading your own models.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchvision import models\n\n# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:\nmodel = models.densenet121(weights='IMAGENET1K_V1')\n# Since we are using our model only for inference, switch to `eval` mode:\nmodel.eval()\n\n\ndef get_prediction(image_bytes):\n    tensor = transform_image(image_bytes=image_bytes)\n    outputs = model.forward(tensor)\n    # Index of the highest-scoring class along the class dimension\n    _, y_hat = outputs.max(1)\n    return y_hat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The tensor `y_hat` will contain the index of the predicted class id.\nHowever, we need a human-readable class name. For that we need a class\nid to name mapping. Download [this\nfile](https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json)\nas `imagenet_class_index.json` and remember where you saved it (or, if\nyou are following the exact steps in this tutorial, save it in\n`tutorials/_static`). 
This file contains the mapping of\nImageNet class id to ImageNet class name. We will load this JSON file\nand get the class name of the predicted index.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import json\n\n# Load the class index mapping; the file is closed once loading finishes\nwith open('../_static/imagenet_class_index.json') as f:\n    imagenet_class_index = json.load(f)\n\ndef get_prediction(image_bytes):\n    tensor = transform_image(image_bytes=image_bytes)\n    outputs = model.forward(tensor)\n    _, y_hat = outputs.max(1)\n    # The keys of the mapping are strings, so convert the index\n    predicted_idx = str(y_hat.item())\n    return imagenet_class_index[predicted_idx]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before indexing into the `imagenet_class_index` dictionary, we convert\nthe tensor value to a string, since the keys in the\n`imagenet_class_index` dictionary are strings. Let's test the above\nmethod:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with open(\"../_static/img/sample_file.jpeg\", 'rb') as f:\n    image_bytes = f.read()\n    print(get_prediction(image_bytes=image_bytes))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should get a response like this:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "['n02124075', 'Egyptian_cat']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first item in the array is the ImageNet class id and the second item is the\nhuman-readable name.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Integrating the model in our API Server\n---------------------------------------\n\nIn this final part we will add our model to our Flask API server.\nSince our API server is supposed to take an image file, we will\nupdate our `predict` method to read files from the requests:\n\n ``` {.python}\n from flask import request\n\n @app.route('/predict', methods=['POST'])\n 
def predict():\n     if request.method == 'POST':\n         # we will get the file from the request\n         file = request.files['file']\n         # convert that to bytes\n         img_bytes = file.read()\n         class_id, class_name = get_prediction(image_bytes=img_bytes)\n         return jsonify({'class_id': class_id, 'class_name': class_name})\n ```\n\nThe `app.py` file is now complete. Following is the full version;\nreplace the paths with the paths where you saved your files and it\nshould run:\n\n``` {.python}\nimport io\nimport json\n\nfrom torchvision import models\nimport torchvision.transforms as transforms\nfrom PIL import Image\nfrom flask import Flask, jsonify, request\n\n\napp = Flask(__name__)\n# Load the class index mapping once at startup\nwith open('/imagenet_class_index.json') as f:\n    imagenet_class_index = json.load(f)\nmodel = models.densenet121(weights='IMAGENET1K_V1')\nmodel.eval()\n\n\ndef transform_image(image_bytes):\n    my_transforms = transforms.Compose([transforms.Resize(255),\n                                        transforms.CenterCrop(224),\n                                        transforms.ToTensor(),\n                                        transforms.Normalize(\n                                            [0.485, 0.456, 0.406],\n                                            [0.229, 0.224, 0.225])])\n    # Convert to RGB so that grayscale or RGBA images still yield 3 channels\n    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')\n    return my_transforms(image).unsqueeze(0)\n\n\ndef get_prediction(image_bytes):\n    tensor = transform_image(image_bytes=image_bytes)\n    outputs = model.forward(tensor)\n    _, y_hat = outputs.max(1)\n    predicted_idx = str(y_hat.item())\n    return imagenet_class_index[predicted_idx]\n\n\n@app.route('/predict', methods=['POST'])\ndef predict():\n    if request.method == 'POST':\n        file = request.files['file']\n        img_bytes = file.read()\n        class_id, class_name = get_prediction(image_bytes=img_bytes)\n        return jsonify({'class_id': class_id, 'class_name': class_name})\n\n\nif __name__ == '__main__':\n    app.run()\n
```\n\nLet's test our web server! Run:\n\n``` {.sh}\nFLASK_ENV=development FLASK_APP=app.py flask run\n```\n\nWe can use the [requests](https://pypi.org/project/requests/)\nlibrary to send a POST request to our app:\n\n``` {.python}\nimport requests\n\nresp = requests.post(\"http://localhost:5000/predict\",\n                     files={\"file\": open('/cat.jpg','rb')})\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Printing `resp.json()` will now show the following:\n\n``` {.sh}\n{\"class_id\": \"n02124075\", \"class_name\": \"Egyptian_cat\"}\n```\n\nNext steps\n==========\n\nThe server we wrote is quite trivial and may not do everything you\nneed for your production application. So, here are some things you\ncan do to make it better:\n\n- The endpoint `/predict` assumes that there will always be an\n  image file in the request. This may not hold true for all\n  requests. Our user may send the image with a different parameter or\n  send no image at all.\n- The user may send non-image type files too. Since we are not\n  handling errors, this will break our server. Adding an explicit\n  error handling path that will throw an exception would allow us\n  to better handle the bad inputs.\n- Even though the model can recognize a large number of classes of\n  images, it may not be able to recognize all images. 
Enhance the\n  implementation to handle cases when the model does not recognize\n  anything in the image.\n- We run the Flask server in development mode, which is not\n  suitable for deploying in production. You can check out [this\n  tutorial](https://flask.palletsprojects.com/en/1.1.x/tutorial/deploy/)\n  for deploying a Flask server in production.\n- You can also add a UI by creating a page with a form which takes\n  the image and displays the prediction.\n- In this tutorial, we only showed how to build a service that\n  could return predictions for a single image at a time. We could\n  modify our service to be able to return predictions for multiple\n  images at once. In addition, the\n  [service-streamer](https://github.com/ShannonAI/service-streamer)\n  library automatically queues requests to your service and\n  samples them into mini-batches that can be fed into your model.\n  You can check out [this\n  tutorial](https://github.com/ShannonAI/service-streamer/wiki/Vision-Recognition-Service-with-Flask-and-service-streamer).\n- Finally, we encourage you to check out our other tutorials on\n  deploying PyTorch models linked to at the top of the page.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }