12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- #!/usr/bin/env python
- """
- @Contact : liuyuqi.gov@msn.cn
- @Time : 2024/03/25 11:34:35
- @License : Copyright © 2017-2022 liuyuqi. All Rights Reserved.
- @Desc : image inference
- """
- # import cv2
- import ssl
- import numpy as np
- import tensorflow as tf
- from PIL import Image
- from tensorflow.keras.applications.imagenet_utils import decode_predictions
# Module-level slot for the shared Keras classifier model. It starts out as
# None; ImageInferenceTask.predict reads this global.
classifier_model = None

# SECURITY WARNING: this swaps the default HTTPS context factory (used by
# urllib / http.client when no explicit context is given) for one that skips
# certificate and hostname verification, process-wide and for every later
# request, not only model downloads.
# NOTE(review): presumably added so pretrained weights can be fetched on a
# host with a broken CA store; confirm it is still needed and scope it down.
ssl._create_default_https_context = ssl._create_unverified_context
class ImageClassificationService:
    """Async facade that classifies an image using a fixed model configuration."""

    def __init__(self):
        """Record the model name and input size, and create the inference task."""
        # NOTE(review): ImageInferenceTask.load_model has no "INCEPTION_V2"
        # branch; every name other than "INCEPTION_V3" yields MobileNetV2.
        self.image_classification_model = "INCEPTION_V2"
        self.IMAGE_SHAPE = (224, 224)
        self.model = ImageInferenceTask()

    async def classify(self, image_file):
        """Run inference on ``image_file`` and return the task's result as-is.

        The configured model name and input shape are forwarded to
        ``self.model.predict`` as keyword arguments.
        """
        request = {
            "classifier_model_name": self.image_classification_model,
            "image": image_file,
            "shape": self.IMAGE_SHAPE,
        }
        return await self.model.predict(**request)
class ImageInferenceTask:
    """Load a Keras ImageNet classifier and run top-1 inference on one image."""

    @classmethod
    async def download_model(cls, classifier_model_name):
        """Pre-fetch weights for ``classifier_model_name``.

        Not implemented: currently a no-op that returns ``None``.
        """
        # classmethod (not staticmethod) so the ``cls`` parameter is bound
        # automatically and callers pass only the model name.
        return None

    @classmethod
    async def load_model(cls, classifier_model_name):
        """Build a Keras ImageNet classifier with pretrained weights.

        Args:
            classifier_model_name: ``"INCEPTION_V3"`` selects InceptionV3;
                any other value (including ``None``) selects MobileNetV2.

        Returns:
            The ``tf.keras`` model constructed with ``weights="imagenet"``
            (Keras may download these weights on first use).
        """
        if classifier_model_name == "INCEPTION_V3":
            return tf.keras.applications.InceptionV3(weights="imagenet")
        return tf.keras.applications.MobileNetV2(weights="imagenet")

    async def predict(
        self,
        image_path=None,
        *,
        classifier_model_name=None,
        image=None,
        shape=None,
    ):
        """Classify one image and return its top-1 ImageNet prediction.

        Args:
            image_path: Path (or file object) of the image; kept as the first
                positional parameter for existing callers.
            classifier_model_name: Model to load if none is cached yet; see
                ``load_model`` for accepted values.
            image: Alternative to ``image_path`` (path or file object), as
                passed by ``ImageClassificationService.classify``; takes
                precedence when both are given.
            shape: ``(height, width)`` to resize to. Defaults to the model's
                own input size, or 224 for any dimension it leaves undefined.

        Returns:
            The result of Keras ``decode_predictions(..., top=1)``: a list
            with one entry for this single image holding its top-1
            ``(class_id, description, score)`` tuple.

        Raises:
            ValueError: if neither ``image_path`` nor ``image`` is given.
        """
        global classifier_model

        source = image if image is not None else image_path
        if source is None:
            raise ValueError("predict() requires image_path or image")

        # Build the (large) Keras model once per process and cache it in the
        # module-level global. NOTE: the first model loaded is reused even if
        # a later call passes a different classifier_model_name.
        if classifier_model is None:
            classifier_model = await self.load_model(classifier_model_name)
        model = classifier_model

        if shape is None:
            # input_shape is (batch, height, width, channels); InceptionV3 and
            # MobileNetV2 differ, so a fixed 224x224 is wrong for InceptionV3.
            height, width = model.input_shape[1:3]
            shape = (height or 224, width or 224)

        # convert("RGB") normalizes grayscale / palette / RGBA inputs to the
        # 3 channels the ImageNet models take; the context manager releases
        # the file handle PIL opens when given a path.
        with Image.open(source) as img:
            pixels = np.array(img.convert("RGB"))

        batch = tf.image.resize(pixels, shape)  # float32, still in [0, 255]
        batch = tf.expand_dims(batch, axis=0)   # add the batch dimension
        # MobileNetV2 and InceptionV3 expect pixels scaled to [-1, 1] (Keras
        # ``preprocess_input`` "tf" mode); the previous /255 gave [0, 1].
        batch = batch / 127.5 - 1.0

        # NOTE(review): PIL/TF work here blocks the event loop despite the
        # ``async`` signature; offload it if request latency matters.
        predictions = model.predict(batch)
        return decode_predictions(predictions, top=1)
|