Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run streamlit dashboard on docker container #1571

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
schema: str,
warehouse: str,
role: str,
host: Optional[str] = None,
init_server_side: bool = True,
database_redact_keys: bool = False,
database_prefix: Optional[str] = None,
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
warehouse,
role,
database_args["database_prefix"],
host,
)

def _initialize_snowflake_server_side_feedback_evaluations(
Expand All @@ -103,6 +105,7 @@ def _initialize_snowflake_server_side_feedback_evaluations(
warehouse: str,
role: str,
database_prefix: str,
host: Optional[str] = None,
):
connection_parameters = {
"account": account,
Expand All @@ -112,6 +115,7 @@ def _initialize_snowflake_server_side_feedback_evaluations(
"schema": schema,
"warehouse": warehouse,
"role": role,
**({"host": host} if host else {})
}
with Session.builder.configs(connection_parameters).create() as session:
ServerSideEvaluationArtifacts(
Expand Down
4 changes: 3 additions & 1 deletion src/dashboard/trulens/dashboard/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def run_dashboard(
address: Optional[str] = None,
force: bool = False,
_dev: Optional[Path] = None,
spcs_runtime: Optional[bool] = False,
) -> Process:
"""Run a streamlit dashboard to view logged results and apps.

Expand Down Expand Up @@ -125,7 +126,8 @@ def run_dashboard(
"--database-prefix",
session.connector.db.table_prefix,
]

if spcs_runtime:
args.append("--spcs-runtime")
proc = subprocess.Popen(
args,
stdout=subprocess.PIPE,
Expand Down
51 changes: 48 additions & 3 deletions src/dashboard/trulens/dashboard/streamlit_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import argparse
import sys

import os
from trulens.core import TruSession
from trulens.core.database import base as mod_db
from snowflake.snowpark import Session
from snowflake.sqlalchemy import URL


def init_from_args():
Expand All @@ -16,6 +18,7 @@ def init_from_args():
parser.add_argument(
"--database-prefix", default=mod_db.DEFAULT_DATABASE_PREFIX
)
# Boolean switch: the dashboard launcher appends a bare "--spcs-runtime" with
# no value, so this must be a store_true flag rather than an option that
# expects an argument (argparse would otherwise reject the command line).
parser.add_argument("--spcs-runtime", action="store_true", default=False)

try:
args = parser.parse_args()
Expand All @@ -27,6 +30,48 @@ def init_from_args():
# so we have to do a hard exit.
sys.exit(e.code)

TruSession(
database_url=args.database_url, database_prefix=args.database_prefix
def get_login_token():
    """
    Return the login token that Snowflake supplies automatically.

    The token is short lived, so call this right before creating each new
    connection instead of caching the value.
    """
    token_file = open("/snowflake/session/token", "r")
    try:
        return token_file.read()
    finally:
        token_file.close()

if args.spcs_runtime:
    # Only read the Snowflake-provided token and open a Snowpark session when
    # the dashboard was started with --spcs-runtime. The token file is
    # supplied by Snowflake and is presumably absent elsewhere, so doing this
    # unconditionally would break the regular (database URL) code path.
    connection_params = {
        "account": os.environ.get("SNOWFLAKE_ACCOUNT"),
        "host": os.getenv("SNOWFLAKE_HOST"),
        "authenticator": "oauth",
        # Tokens are short lived: read one right before connecting.
        "token": get_login_token(),
        "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE"),
        "database": os.environ.get("SNOWFLAKE_DATABASE"),
        "schema": os.environ.get("SNOWFLAKE_SCHEMA"),
    }
    snowpark_session = Session.builder.configs(connection_params).create()

    # Set up sqlalchemy engine parameters so the engine reuses the Snowpark
    # session's existing connection.
    conn = snowpark_session.connection
    engine_params = {
        "paramstyle": "qmark",
        "creator": lambda: conn,
    }
    database_args = {"engine_params": engine_params}
    db_url = URL(
        account=snowpark_session.get_current_account(),
        user=snowpark_session.get_current_user(),
        # NOTE(review): placeholder value; connections come from `creator`
        # above, so this is presumably never used to authenticate - confirm.
        password="password",
        database=snowpark_session.get_current_database(),
        schema=snowpark_session.get_current_schema(),
        warehouse=snowpark_session.get_current_warehouse(),
        role=snowpark_session.get_current_role(),
    )
    TruSession(
        database_url=db_url,
        database_check_revision=False,
        database_args=database_args,
        database_prefix=args.database_prefix,
    )
else:
    # Default path: connect using the database URL passed on the command line.
    TruSession(
        database_url=args.database_url, database_prefix=args.database_prefix
    )
12 changes: 12 additions & 0 deletions tools/snowflake/spcs_dashboard/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Base image; override at build time with --build-arg BASE_IMAGE=<image>.
ARG BASE_IMAGE=python:3.11.9-slim-bullseye
FROM $BASE_IMAGE

# Copy the build context (requirements.txt, the local wheels and the scripts)
# into the image.
COPY ./ /trulens_dashboard/

WORKDIR /trulens_dashboard

# --no-cache-dir keeps pip's download cache out of the image layers.
RUN pip install --no-cache-dir -r requirements.txt
# The locally built wheels are installed explicitly for now; they are meant to
# move into requirements.txt once the next trulens version is released.
RUN pip install --no-cache-dir trulens_connectors_snowflake-1.0.1-py3-none-any.whl
RUN pip install --no-cache-dir trulens_dashboard-1.0.1-py3-none-any.whl
Comment on lines +9 to +10
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is temporary; once the next version of trulens is released, these can go in requirements.txt.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check with corey whether requirements.txt can be done using poetry instead.


CMD ["python", "run_dashboard.py"]
7 changes: 7 additions & 0 deletions tools/snowflake/spcs_dashboard/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
python-dotenv
pydantic
snowflake[ml]
snowflake-connector-python
snowflake-sqlalchemy
trulens
# trulens-connectors-snowflake
103 changes: 103 additions & 0 deletions tools/snowflake/spcs_dashboard/run_container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os
from snowflake.snowpark import Session
from snowflake.snowpark.functions import call_builtin

# Read Snowflake connection settings from environment variables.
account = os.getenv('SNOWFLAKE_ACCOUNT')
user = os.getenv('SNOWFLAKE_USER')
password = os.getenv('SNOWFLAKE_PASSWORD')
role = os.getenv('SNOWFLAKE_ROLE')
warehouse = os.getenv('SNOWFLAKE_WAREHOUSE')
database = os.getenv('SNOWFLAKE_DATABASE')
schema = os.getenv('SNOWFLAKE_SCHEMA')

# Define Snowflake connection parameters
connection_parameters = {
    "account": account,
    "user": user,
    "password": password,
    "role": role,
    "warehouse": warehouse,
    "database": database,
    "schema": schema,
}

# Create a Snowflake session
session = Session.builder.configs(connection_parameters).create()

# Create compute pool if it does not exist
# The pool name is taken from SNOWFLAKE_COMPUTE_POOL when set, so the script can
# run unattended; otherwise fall back to the interactive prompt.
compute_pool_name = os.getenv('SNOWFLAKE_COMPUTE_POOL') or input("Enter compute pool name: ")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using input() for getting the compute pool name is not suitable for automated scripts or production environments. Consider using environment variables or command-line arguments instead.

# Look the requested name up among the pools returned by SHOW COMPUTE POOLS
# (compared against the upper-cased name) and create the pool only if missing.
compute_pools = session.sql("SHOW COMPUTE POOLS").collect()
existing_pool_names = {pool['name'] for pool in compute_pools}
compute_pool_exists = compute_pool_name.upper() in existing_pool_names
if compute_pool_exists:
    print(f"Compute pool {compute_pool_name} already exists")
else:
    create_pool_statement = f"CREATE COMPUTE POOL {compute_pool_name} MIN_NODES = 1 MAX_NODES = 1 INSTANCE_FAMILY = CPU_X64_M"
    session.sql(create_pool_statement).collect()
session.sql(f"DESCRIBE COMPUTE POOL {compute_pool_name}").collect()

# Create image repository. IF NOT EXISTS keeps the script re-runnable, in line
# with the CREATE OR REPLACE statements used for the objects below.
image_repository_name = "trulens_image_repository"
session.sql(f"CREATE IMAGE REPOSITORY IF NOT EXISTS {image_repository_name}").collect()
session.sql("SHOW IMAGE REPOSITORIES").collect()

# Create network rule: an egress rule covering ports 443 and 80 on 0.0.0.0,
# named after the compute pool.
network_rule_name = f"{compute_pool_name}_allow_all_network_rule"
session.sql(f"CREATE OR REPLACE NETWORK RULE {network_rule_name} TYPE = 'HOST_PORT' MODE = 'EGRESS' VALUE_LIST = ('0.0.0.0:443','0.0.0.0:80')").collect()
session.sql("SHOW NETWORK RULES").collect()

# Create external access integration that enables the network rule above.
access_integration_name = f"{compute_pool_name}_access_integration"
session.sql(f"CREATE OR REPLACE EXTERNAL ACCESS INTEGRATION {access_integration_name} ALLOWED_NETWORK_RULES = ({network_rule_name}) ENABLED = true").collect()
session.sql("SHOW EXTERNAL ACCESS INTEGRATIONS").collect()

app_name = "trulens_dashboard"
secret_name = f"{schema}.{app_name}_login_credentials"


def _sql_string_literal(value):
    """Return *value* as a single-quoted SQL string literal.

    Backslashes and single quotes are escaped so that credentials containing
    them cannot terminate the literal early.
    """
    escaped = str(value).replace("\\", "\\\\").replace("'", "''")
    return f"'{escaped}'"


# Store the login credentials as a password-type secret; the service
# specification below exposes them to the container as environment variables.
session.sql(
    f"CREATE SECRET {secret_name} TYPE = PASSWORD "
    f"USERNAME = {_sql_string_literal(user)} "
    f"PASSWORD = {_sql_string_literal(password)}"
).collect()

# Create the dashboard service in the compute pool. The specification between
# the $$ markers is YAML, so its nesting is significant.
# NOTE(review): the image path uses the container name as its repository
# segment, while the repository created by this script is
# `trulens_image_repository` - confirm which one is intended.
service_name = compute_pool_name + "_trulens_dashboard"
session.sql("""
CREATE SERVICE {service_name}
IN COMPUTE POOL {compute_pool_name}
EXTERNAL_ACCESS_INTEGRATIONS = ({access_integration_name})
FROM SPECIFICATION $$
spec:
  containers:
  - name: {container_name}
    image: /{database}/{schema}/{container_name}/{app_name}:latest
    env:
      SNOWFLAKE_ACCOUNT: "{account}"
      SNOWFLAKE_DATABASE: "{database}"
      SNOWFLAKE_SCHEMA: "{schema}"
      SNOWFLAKE_WAREHOUSE: "{warehouse}"
      SNOWFLAKE_ROLE: "{role}"
      RUN_DASHBOARD: "1"
    secrets:
    - snowflakeSecret: {secret_name}
      secretKeyRef: username
      envVarName: SNOWFLAKE_USER
    - snowflakeSecret: {secret_name}
      secretKeyRef: password
      envVarName: SNOWFLAKE_PASSWORD
  endpoints:
  - name: trulens-demo-dashboard-endpoint
    port: 8484
    public: true
$$
""".format(service_name=service_name,
           compute_pool_name=compute_pool_name,
           access_integration_name=access_integration_name,
           container_name=app_name + "_container",
           account=account,
           database=database,
           schema=schema,
           warehouse=warehouse,
           role=role,
           app_name=app_name,
           # Referenced twice in the spec; omitting it raises KeyError.
           secret_name=secret_name)).collect()

# Show the service's endpoints and its status. The status is queried for the
# service named by `service_name` rather than for hard-coded,
# developer-specific service names, and the rows are printed so the calls have
# a visible result when run as a script.
try:
    for endpoint_row in session.sql(f"SHOW ENDPOINTS IN SERVICE {service_name}").collect():
        print(endpoint_row)
    for status_row in session.sql(f"CALL SYSTEM$GET_SERVICE_STATUS('{service_name}')").collect():
        print(status_row)
finally:
    # Close the session even if one of the status queries fails.
    session.close()
69 changes: 69 additions & 0 deletions tools/snowflake/spcs_dashboard/run_dashboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from trulens.dashboard import run_dashboard
from trulens.core import TruSession
from snowflake.snowpark import Session
from trulens.connectors.snowflake import SnowflakeConnector
from snowflake.sqlalchemy import URL
import os

# connection_params = {
# "account": os.environ.get("SNOWFLAKE_ACCOUNT"),
# "host": os.getenv("SNOWFLAKE_HOST"),
# "user": os.environ.get("SNOWFLAKE_USER"),
# "password": os.environ.get("SNOWFLAKE_PASSWORD"),
# "database": os.environ.get("SNOWFLAKE_DATABASE"),
# "schema": os.environ.get("SNOWFLAKE_SCHEMA"),
# "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE"),
# "role": os.environ.get("SNOWFLAKE_ROLE"),
# "init_server_side": False,
# }

# connector = SnowflakeConnector(**connection_params)
# session = TruSession(connector=connector, init_server_side=False)

def get_login_token():
    """
    Return the contents of the login token file supplied automatically by
    Snowflake.

    These tokens are short lived, so read one right before creating any new
    connection rather than reusing an earlier value.
    """
    with open("/snowflake/session/token", mode="r") as token_file:
        token = token_file.read()
    return token

# Connection settings: OAuth with a freshly read login token; the remaining
# values come from environment variables.
connection_params = dict(
    account=os.environ.get("SNOWFLAKE_ACCOUNT"),
    host=os.getenv("SNOWFLAKE_HOST"),
    authenticator="oauth",
    token=get_login_token(),
    warehouse=os.environ.get("SNOWFLAKE_WAREHOUSE"),
    database=os.environ.get("SNOWFLAKE_DATABASE"),
    schema=os.environ.get("SNOWFLAKE_SCHEMA"),
)
snowpark_session = Session.builder.configs(connection_params).create()

# Set up sqlalchemy engine parameters: the engine's `creator` hands back the
# Snowpark session's existing connection.
conn = snowpark_session.connection
engine_params = {
    "paramstyle": "qmark",
    "creator": lambda: conn,
}
database_args = dict(engine_params=engine_params)
# # Ensure any Cortex provider uses the only Snowflake connection allowed in this stored procedure.
# trulens.providers.cortex.provider._SNOWFLAKE_STORED_PROCEDURE_CONNECTION = (
# conn
# )
# Build the SQLAlchemy URL from the current Snowpark session.
db_url = URL(
account=snowpark_session.get_current_account(),
user=snowpark_session.get_current_user(),
password="password",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Storing passwords directly in the code is a security risk. Consider using environment variables or a secure vault to manage sensitive information.

database=snowpark_session.get_current_database(),
schema=snowpark_session.get_current_schema(),
warehouse=snowpark_session.get_current_warehouse(),
role=snowpark_session.get_current_role(),
)
# Open the TruLens session on top of the Snowpark-backed engine settings.
tru_session_options = dict(
    database_url=db_url,
    database_check_revision=False,  # TODO: check revision in the future?
    database_args=database_args,
)
tru_session = TruSession(**tru_session_options)
# The result is discarded.
# NOTE(review): presumably a connectivity / warm-up check - confirm.
tru_session.get_records_and_feedback()

# Start the dashboard on port 8484 in SPCS mode.
run_dashboard(tru_session, port=8484, spcs_runtime=True)