Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions edition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import importlib.abc
import importlib.machinery
import importlib.util
import os
import sys

EDITION = os.getenv("EDITION", "ce")

# ce is src/ itself; ee and cloud entries are override layers on top
_OVERRIDE_CHAINS = {
"ce": [],
"ee": ["ee"],
"cloud": ["cloud", "ee"],
}

_PACKAGES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "packages")


class _EditionFinder(importlib.abc.MetaPathFinder):
def find_spec(self, fullname, path, target=None):
if not fullname.startswith("src.") or fullname == "src.edition":
return None

overrides = _OVERRIDE_CHAINS.get(EDITION, [])
if not overrides:
return None # ce — normal src/ resolution handles everything

rel = fullname[len("src."):].replace(".", os.sep)

for edition in overrides:
base = os.path.join(_PACKAGES_DIR, edition, "src")

init = os.path.join(base, rel, "__init__.py")
if os.path.exists(init):
return importlib.util.spec_from_file_location(
fullname,
init,
submodule_search_locations=[os.path.join(base, rel)],
)

module = os.path.join(base, rel + ".py")
if os.path.exists(module):
return importlib.util.spec_from_file_location(fullname, module)

# no override found — fall through to normal src/ resolution
return None


def install():
sys.meta_path.insert(0, _EditionFinder())
3 changes: 3 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import grpc

from edition import install as _install_edition
_install_edition()

from src.interceptor.auth_interceptor import AuthInterceptor
from src.logger import get_logger

Expand Down
13 changes: 13 additions & 0 deletions packages/cloud/src/schema/model_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import List, Optional

from pydantic import BaseModel, Field


class Model(BaseModel):
identifier: str
name: str
capabilities: List[str] = Field(default=[])
provider: str
api: Optional[str] = Field(default=None)
auth: str
token_cost: Optional[float] = Field(default=1.0)
Empty file added packages/ee/src/.gitkeep
Empty file.
18 changes: 12 additions & 6 deletions src/endpoint/generation/generate_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,19 +189,22 @@ def Prompt(self, request: pb2.PromptRequest, context) -> pb2.FlowResponse:
)
log.info(f"[Prompt] Generating flow...")

model = self.model_store.find(identifier=request.model_identifier)
t0 = time.time()
try:
generated_flow, completion = self.prompt_orchestrator.generate(
model=self.model_store.find(identifier=request.model_identifier),
model=model,
prompt=request.prompt,
few_shots=few_shots,
available_functions=combined_functions,
available_flow_types=combined_flow_types
)

elapsed = time.time() - t0
token_cost = model.token_cost if model.token_cost is not None else 1.0
usage = int(completion.usage.total_tokens * token_cost)
log.success( # type: ignore[attr-defined]
f"[Prompt] Generated '{generated_flow.name}' in {elapsed:.2f}s | tokens={completion.usage.total_tokens}"
f"[Prompt] Generated '{generated_flow.name}' in {elapsed:.2f}s | tokens={completion.usage.total_tokens} cost_factor={token_cost} usage={usage}"
)

current_time_ms = int(time.time() * 1000)
Expand All @@ -214,7 +217,7 @@ def Prompt(self, request: pb2.PromptRequest, context) -> pb2.FlowResponse:
)
),
cached_until=current_time_ms + 300000,
usage=completion.usage.total_tokens
usage=usage
)
except Exception as e:
elapsed = time.time() - t0
Expand Down Expand Up @@ -390,10 +393,11 @@ def Flow(self, request: pb2.FlowRequest, context) -> pb2.FlowResponse:
)
log.info(f"[Flow] Modifying flow...")

model = self.model_store.find(identifier=request.model_identifier)
t0 = time.time()
try:
generated_flow, completion = self.flow_orchestrator.generate(
model=self.model_store.find(identifier=request.model_identifier),
model=model,
prompt=request.prompt,
flow=map_to_flow_schema(request.flow),
few_shots=few_shots,
Expand All @@ -402,8 +406,10 @@ def Flow(self, request: pb2.FlowRequest, context) -> pb2.FlowResponse:
)

elapsed = time.time() - t0
token_cost = model.token_cost if model.token_cost is not None else 1.0
usage = int(completion.usage.total_tokens * token_cost)
log.success( # type: ignore[attr-defined]
f"[Flow] Modified '{generated_flow.name}' in {elapsed:.2f}s | tokens={completion.usage.total_tokens}"
f"[Flow] Modified '{generated_flow.name}' in {elapsed:.2f}s | tokens={completion.usage.total_tokens} cost_factor={token_cost} usage={usage}"
)

current_time_ms = int(time.time() * 1000)
Expand All @@ -416,7 +422,7 @@ def Flow(self, request: pb2.FlowRequest, context) -> pb2.FlowResponse:
)
),
cached_until=current_time_ms + 300000,
usage=completion.usage.total_tokens
usage=usage
)
except Exception as e:
elapsed = time.time() - t0
Expand Down
11 changes: 9 additions & 2 deletions src/schema/model_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator


class Model(BaseModel):
Expand All @@ -10,4 +10,11 @@ class Model(BaseModel):
provider: str
api: Optional[str] = Field(default=None)
auth: str
token_cost: Optional[int] = Field(default=1.0)
token_cost: Optional[float] = Field(default=1.0)

@field_validator("token_cost")
@classmethod
def token_cost_min_one(cls, v: Optional[float]) -> Optional[float]:
if v is not None and v < 1:
raise ValueError("token_cost must be >= 1")
return v