diff --git a/edition.py b/edition.py new file mode 100644 index 0000000..82971c5 --- /dev/null +++ b/edition.py @@ -0,0 +1,50 @@ +import importlib.abc +import importlib.machinery +import importlib.util +import os +import sys + +EDITION = os.getenv("EDITION", "ce") + +# ce is src/ itself; ee and cloud entries are override layers on top +_OVERRIDE_CHAINS = { + "ce": [], + "ee": ["ee"], + "cloud": ["cloud", "ee"], +} + +_PACKAGES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "packages") + + +class _EditionFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + if not fullname.startswith("src.") or fullname == "src.edition": + return None + + overrides = _OVERRIDE_CHAINS.get(EDITION, []) + if not overrides: + return None # ce — normal src/ resolution handles everything + + rel = fullname[len("src."):].replace(".", os.sep) + + for edition in overrides: + base = os.path.join(_PACKAGES_DIR, edition, "src") + + init = os.path.join(base, rel, "__init__.py") + if os.path.exists(init): + return importlib.util.spec_from_file_location( + fullname, + init, + submodule_search_locations=[os.path.join(base, rel)], + ) + + module = os.path.join(base, rel + ".py") + if os.path.exists(module): + return importlib.util.spec_from_file_location(fullname, module) + + # no override found — fall through to normal src/ resolution + return None + + +def install(): + sys.meta_path.insert(0, _EditionFinder()) diff --git a/main.py b/main.py index a2c93bc..2a91988 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,9 @@ import grpc +from edition import install as _install_edition +_install_edition() + from src.interceptor.auth_interceptor import AuthInterceptor from src.logger import get_logger diff --git a/packages/cloud/src/schema/model_schema.py b/packages/cloud/src/schema/model_schema.py new file mode 100644 index 0000000..9a5feef --- /dev/null +++ b/packages/cloud/src/schema/model_schema.py @@ -0,0 +1,13 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class Model(BaseModel): + identifier: str + name: str + capabilities: List[str] = Field(default=[]) + provider: str + api: Optional[str] = Field(default=None) + auth: str + token_cost: Optional[float] = Field(default=1.0) diff --git a/packages/ee/src/.gitkeep b/packages/ee/src/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/src/endpoint/generation/generate_endpoint.py b/src/endpoint/generation/generate_endpoint.py index a827865..de93d43 100644 --- a/src/endpoint/generation/generate_endpoint.py +++ b/src/endpoint/generation/generate_endpoint.py @@ -189,10 +189,11 @@ def Prompt(self, request: pb2.PromptRequest, context) -> pb2.FlowResponse: ) log.info(f"[Prompt] Generating flow...") + model = self.model_store.find(identifier=request.model_identifier) t0 = time.time() try: generated_flow, completion = self.prompt_orchestrator.generate( - model=self.model_store.find(identifier=request.model_identifier), + model=model, prompt=request.prompt, few_shots=few_shots, available_functions=combined_functions, @@ -200,8 +201,10 @@ def Prompt(self, request: pb2.PromptRequest, context) -> pb2.FlowResponse: ) elapsed = time.time() - t0 + token_cost = model.token_cost if model.token_cost is not None else 1.0 + usage = int(completion.usage.total_tokens * token_cost) log.success( # type: ignore[attr-defined] - f"[Prompt] Generated '{generated_flow.name}' in {elapsed:.2f}s | tokens={completion.usage.total_tokens}" + f"[Prompt] Generated '{generated_flow.name}' in {elapsed:.2f}s | tokens={completion.usage.total_tokens} cost_factor={token_cost} usage={usage}" ) current_time_ms = int(time.time() * 1000) @@ -214,7 +217,7 @@ def Prompt(self, request: pb2.PromptRequest, context) -> pb2.FlowResponse: ) ), cached_until=current_time_ms + 300000, - usage=completion.usage.total_tokens + usage=usage ) except Exception as e: elapsed = time.time() - t0 @@ -390,10 +393,11 @@ def Flow(self, request: pb2.FlowRequest, context) -> pb2.FlowResponse: ) log.info(f"[Flow] Modifying flow...") + model = self.model_store.find(identifier=request.model_identifier) t0 = time.time() try: generated_flow, completion = self.flow_orchestrator.generate( - model=self.model_store.find(identifier=request.model_identifier), + model=model, prompt=request.prompt, flow=map_to_flow_schema(request.flow), few_shots=few_shots, @@ -402,8 +406,10 @@ def Flow(self, request: pb2.FlowRequest, context) -> pb2.FlowResponse: ) elapsed = time.time() - t0 + token_cost = model.token_cost if model.token_cost is not None else 1.0 + usage = int(completion.usage.total_tokens * token_cost) log.success( # type: ignore[attr-defined] - f"[Flow] Modified '{generated_flow.name}' in {elapsed:.2f}s | tokens={completion.usage.total_tokens}" + f"[Flow] Modified '{generated_flow.name}' in {elapsed:.2f}s | tokens={completion.usage.total_tokens} cost_factor={token_cost} usage={usage}" ) current_time_ms = int(time.time() * 1000) @@ -416,7 +422,7 @@ def Flow(self, request: pb2.FlowRequest, context) -> pb2.FlowResponse: ) ), cached_until=current_time_ms + 300000, - usage=completion.usage.total_tokens + usage=usage ) except Exception as e: elapsed = time.time() - t0 diff --git a/src/schema/model_schema.py b/src/schema/model_schema.py index 55655f6..fe0642a 100644 --- a/src/schema/model_schema.py +++ b/src/schema/model_schema.py @@ -1,6 +1,6 @@ from typing import List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator class Model(BaseModel): @@ -10,4 +10,11 @@ class Model(BaseModel): provider: str api: Optional[str] = Field(default=None) auth: str - token_cost: Optional[int] = Field(default=1.0) + token_cost: Optional[float] = Field(default=1.0) + + @field_validator("token_cost") + @classmethod + def token_cost_min_one(cls, v: Optional[float]) -> Optional[float]: + if v is not None and v < 1: + raise ValueError("token_cost must be >= 1") + return v