#!/usr/bin/env python
"""
@Contact : liuyuqi.gov@msn.cn
@Time : 2024/03/25 11:34:35
@License : Copyright © 2017-2022 liuyuqi. All Rights Reserved.
@Desc : image inference with ImageNet-pretrained Keras classifiers
"""
import ssl

import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.applications.imagenet_utils import decode_predictions

# SECURITY NOTE(review): this replaces the ssl module's default HTTPS context
# with one that skips certificate verification, affecting every HTTPS
# connection in this process that relies on that default (e.g. urllib /
# http.client), not only the Keras weight download. Presumably a workaround
# for a broken CA store; prefer fixing the certificates and removing this.
ssl._create_default_https_context = ssl._create_unverified_context

# Optional module-level model. If a caller assigns a Keras model here and then
# calls ImageInferenceTask.predict without a model name, that model is used.
classifier_model = None

# Fallback (height, width) used when a model's input size cannot be determined.
_DEFAULT_INPUT_SIZE = (224, 224)


class ImageClassificationService:
    """High-level entry point: classify an image with a configured model."""

    def __init__(self):
        """Create the inference task and choose the model name and input size."""
        self.model = ImageInferenceTask()
        # NOTE: ImageInferenceTask.load_model only recognises "INCEPTION_V3";
        # any other name, including this one, loads MobileNetV2.
        self.image_classification_model = "INCEPTION_V2"
        # (height, width) the image is resized to before inference.
        self.IMAGE_SHAPE = (224, 224)

    async def classify(self, image_file):
        """Return the top-1 ImageNet prediction for *image_file*.

        Args:
            image_file: A path or binary file object accepted by
                ``PIL.Image.open``.

        Returns:
            The output of ``decode_predictions(..., top=1)``: one list per
            image, each holding ``(class_id, class_name, score)`` tuples.
        """
        label = await self.model.predict(
            image_file,
            classifier_model_name=self.image_classification_model,
            shape=self.IMAGE_SHAPE,
        )
        return label


class ImageInferenceTask:
    """Load ImageNet-pretrained Keras classifiers and run single-image inference."""

    # Models loaded by _get_model, keyed by model name. Shared by all
    # instances so each model is built (and its weights loaded) only once.
    _models = {}

    @classmethod
    async def download_model(cls, classifier_model_name):
        """Placeholder for pre-fetching model weights; currently a no-op."""

    @classmethod
    async def load_model(cls, classifier_model_name):
        """Build a new ImageNet-pretrained Keras model.

        Args:
            classifier_model_name: ``"INCEPTION_V3"`` selects InceptionV3;
                any other value selects MobileNetV2.

        Returns:
            A Keras model constructed with ``weights="imagenet"`` (Keras
            fetches the weights if they are not already cached locally).
        """
        if classifier_model_name == "INCEPTION_V3":
            return tf.keras.applications.InceptionV3(weights="imagenet")
        return tf.keras.applications.MobileNetV2(weights="imagenet")

    @classmethod
    async def _get_model(cls, classifier_model_name):
        """Return the cached model for *classifier_model_name*, loading it once."""
        model = cls._models.get(classifier_model_name)
        if model is None:
            model = await cls.load_model(classifier_model_name)
            cls._models[classifier_model_name] = model
        return model

    @staticmethod
    def _input_size(model):
        """Return the (height, width) *model* expects, or the module default."""
        try:
            height, width = model.input_shape[1:3]
        except (AttributeError, TypeError, ValueError):
            return _DEFAULT_INPUT_SIZE
        if isinstance(height, int) and isinstance(width, int):
            return (height, width)
        # Unknown / variable spatial dims (None) or an unexpected layout.
        return _DEFAULT_INPUT_SIZE

    async def predict(self, image_path, classifier_model_name=None, shape=None):
        """Classify one image and return its top-1 ImageNet prediction.

        Args:
            image_path: A path or binary file object accepted by
                ``PIL.Image.open``.
            classifier_model_name: Model to use (see ``load_model``). When
                omitted, the module-level ``classifier_model`` is used if it
                has been set; otherwise the default model (MobileNetV2) is
                loaded and cached.
            shape: ``(height, width)`` to resize the image to. Defaults to
                the model's own input size (falling back to 224x224).

        Returns:
            ``decode_predictions(..., top=1)`` output for the single image.
        """
        if classifier_model_name is None and classifier_model is not None:
            model = classifier_model
        else:
            model = await self._get_model(classifier_model_name)
        if shape is None:
            shape = self._input_size(model)

        # Force 3-channel RGB so greyscale/RGBA/palette images match the
        # model's expected channel count; the context manager releases the
        # file handle Pillow opened.
        with Image.open(image_path) as img:
            pixels = np.asarray(img.convert("RGB"))

        batch = tf.expand_dims(tf.image.resize(pixels, shape), axis=0)
        # MobileNetV2 and InceptionV3 expect inputs scaled to [-1, 1]
        # (Keras "tf"-mode preprocess_input), not [0, 1].
        batch = batch / 127.5 - 1.0

        # NOTE(review): model loading and model.predict are blocking calls;
        # they stall the event loop while they run.
        predictions = model.predict(batch)
        return decode_predictions(predictions, top=1)