# app.py (1.1 KB)

import typing

import gradio as gr
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

from demo4.predict import predict
# FastAPI application that serves the JSON POST /predict endpoint; the Gradio
# UI is mounted onto this same app at the bottom of the module.
app = FastAPI()
  7. class Request(BaseModel):
  8. question: str
  9. class Result(BaseModel):
  10. score: float
  11. title: str
  12. text: str
  13. class Response(BaseModel):
  14. results: typing.List[Result]
  15. @app.post("/predict", response_model=Response)
  16. async def predict_api(request: Request):
  17. results = predict(request.question)
  18. return Response(
  19. results=[
  20. Result(score=r["score"], title=r["title"], text=r["text"]) for r in results
  21. ]
  22. )
  23. def gradio_predict(question: str):
  24. results = predict(question)
  25. best_result = results[0]
  26. return f"{best_result['title']}\n\n{best_result['text']}", best_result["score"]
  27. demo = gr.Interface(
  28. fn=gradio_predict,
  29. inputs=gr.Textbox(
  30. label="Ask a question", placeholder="What is the capital of France?"
  31. ),
  32. outputs=[gr.Textbox(label="Answer"), gr.Number(label="Score")],
  33. allow_flagging="never",
  34. )
# Mount the Gradio UI onto the FastAPI app at "/"; `fastapp` is the combined
# ASGI application. The POST /predict route was registered on `app` before
# this mount, so it is matched first and is not shadowed by the "/" mount.
# For a standalone Gradio server instead, call `demo.launch()`.
fastapp = gr.mount_gradio_app(app, demo, path="/")