image_inference.py 980 B

1234567891011121314151617181920212223242526272829303132333435
  1. #!/usr/bin/env python
  2. """
  3. @Contact : liuyuqi.gov@msn.cn
  4. @Time : 2024/03/25 11:34:35
  5. @License : Copyright © 2017-2022 liuyuqi. All Rights Reserved.
  6. @Desc : image inference
  7. """
from typing import Any, Optional

from fastapi import APIRouter, File, HTTPException, UploadFile

from app.models.item import ItemsOut
from app.service.image_inference import ImageClassificationService
  12. router = APIRouter()
  13. image_inference = ImageClassificationService()
  14. @router.get("/predict", response_model=ItemsOut)
  15. async def predict(file: UploadFile = File(...)) -> Any:
  16. """
  17. Predict image category.
  18. """
  19. extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
  20. if not extension:
  21. return "Image must be jpg or png format!"
  22. # logger.info('Image Classification')
  23. image = await BasicImageUtils.read_image_file(
  24. await file.read(), filename=file.filename, cache=True
  25. )
  26. image_category = await image_inference.classify(image)
  27. return image_category