image_inference.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. #!/usr/bin/env python
  2. """
  3. @Contact : liuyuqi.gov@msn.cn
  4. @Time : 2024/03/25 11:34:35
  5. @License : Copyright © 2017-2022 liuyuqi. All Rights Reserved.
  6. @Desc : image inference
  7. """
  8. # import cv2
  9. import ssl
  10. import numpy as np
  11. import tensorflow as tf
  12. from PIL import Image
  13. from tensorflow.keras.applications.imagenet_utils import decode_predictions
  14. ssl._create_default_https_context = ssl._create_unverified_context
  15. classifier_model = None
  16. class ImageClassificationService:
  17. """ """
  18. def __init__(self):
  19. """ """
  20. # self.logger = logging.getLogger(__name__)
  21. self.model = ImageInferenceTask()
  22. self.image_classification_model = "INCEPTION_V2"
  23. self.IMAGE_SHAPE = (224, 224)
  24. async def classify(self, image_file):
  25. label = await self.model.predict(
  26. classifier_model_name=self.image_classification_model,
  27. image=image_file,
  28. shape=self.IMAGE_SHAPE,
  29. )
  30. return label
  31. class ImageInferenceTask:
  32. @staticmethod
  33. async def download_model(cls, classifier_model_name):
  34. pass
  35. @staticmethod
  36. async def load_model(cls, classifier_model_name):
  37. """ """
  38. if classifier_model_name == "INCEPTION_V3":
  39. model = tf.keras.applications.InceptionV3(weights="imagenet")
  40. else:
  41. model = tf.keras.applications.MobileNetV2(weights="imagenet")
  42. return model
  43. async def predict(self, image_path):
  44. """ """
  45. # image = cv2.imread(image_path)
  46. image = Image.open(image_path)
  47. image = np.array(image)
  48. image = tf.image.resize(image, (224, 224))
  49. image = tf.expand_dims(image, axis=0)
  50. image = image / 255.0
  51. predictions = classifier_model.predict(image)
  52. predictions = decode_predictions(predictions, top=1)
  53. return predictions