fix: hf hosted inference check (#1128)
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain import HuggingFaceHub
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
@@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
||||
from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
|
||||
|
||||
|
||||
class HuggingfaceHubModel(BaseLLM):
|
||||
@@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
streaming=streaming
|
||||
)
|
||||
else:
|
||||
client = HuggingFaceHub(
|
||||
client = HuggingFaceHubLLM(
|
||||
repo_id=self.name,
|
||||
task=self.credentials['task_type'],
|
||||
model_kwargs=provider_model_kwargs,
|
||||
@@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
if 'baichuan' in self.name.lower():
|
||||
return False
|
||||
|
||||
return True
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
Reference in New Issue
Block a user