Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions orca/orca.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,8 +943,25 @@ def _collect_variables(names, expressions=None):
if '.' in expression:
# Registered variable expression refers to column.
table_name, column_name = expression.split('.')
table = get_table(table_name)
variables[label] = table.get_column(column_name)

if column_name == '*':
# return a table view with all columns
variables[label] = get_table_view(table_name)
elif column_name == 'local':
# return a table view with just the local columns
variables[label] = get_table_view(table_name, 'local')
else:
# return a single column
table = get_table(table_name)
variables[label] = table.get_column(column_name)

elif '[' in expression and expression.endswith(']'):
# evaluate a subset of columns
table_name, cols = expression[:-1].split('[')
cols = cols.split(',')
cols = map(str.strip, cols)
variables[label] = get_table_view(table_name, cols)

else:
thing = all_variables[expression]
if isinstance(thing, (_InjectableFuncWrapper, TableFuncWrapper)):
Expand Down Expand Up @@ -1069,6 +1086,45 @@ def get_table(table_name):
return table


def get_table_view(table_name, columns=None):
"""
Get a view of the registered table.

Parameters
----------
table_name: str
Name of the registered orca table.
columns: str or list, optional, default None
Subset of columns to collect.
Use the 'local' keyword to fetch all local columns.

Returns
-------
pandas.DataFrame with extensions:
- wrapper: returns the orca table for the view.
- update_col: updates or adds a column to the wrapper.
- update_col_from_series: updates a column in the wrapper from a series.

"""
wrapper = get_table(table_name)

# handle local keyword
if columns == 'local':
columns = wrapper.local_columns
elif isinstance(columns, list) and 'local' in columns:
columns = wrapper.local_columns + columns
columns.remove('local')

# evaluate the table
df = wrapper.to_frame(columns)

# add extension methods and return
df.wrapper = wrapper
df.update_col = wrapper.update_col
df.update_col_from_series = wrapper.update_col_from_series
return df


def table_type(table_name):
"""
Returns the type of a registered table.
Expand Down
89 changes: 89 additions & 0 deletions orca/tests/test_orca.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import tempfile

import numpy as np
import pandas as pd
import pytest
from pandas.util import testing as pdt
Expand Down Expand Up @@ -375,6 +376,48 @@ def test_update_col(df):
wrapped['a'], pd.Series([1, 99, 3], index=df.index, name='a'))


def test_table_view():
# register a table
orca.add_table(
'test_tab',
pd.DataFrame({
'a': [1, 2],
'b': [3, 4]
})
)

# attach columns
@orca.column('test_tab')
def c():
return pd.Series([5, 6])

@orca.column('test_tab')
def d():
return pd.Series([7, 8])

# test 1 - all columns
tv1 = orca.get_table_view('test_tab')
assert (tv1.values.flatten() == [1, 3, 5, 7, 2, 4, 6, 8]).all()

# test 2 - just local
tv2 = orca.get_table_view('test_tab', 'local')
assert (tv2.values.flatten() == [1, 3, 2, 4]).all()

# test 3 - specific columns
tv3 = orca.get_table_view('test_tab', ['a', 'c'])
assert (tv3.values.flatten() == [1, 5, 2, 6]).all()

# test 4 - local + specific extra cols
tv4 = orca.get_table_view('test_tab', ['local', 'd'])
assert (tv4.values.flatten() == [1, 3, 7, 2, 4, 8]).all()

# test updating the wrapper's data
tv4.update_col('a', 0)
tv4.update_col_from_series('b', pd.Series([-1], index=pd.Index([1])))
tv5 = orca.get_table_view('test_tab', 'local')
assert (tv5.values.flatten() == [0, 3, 0, -1]).all()


class _FakeTable(object):
def __init__(self, name, columns):
self.name = name
Expand Down Expand Up @@ -1256,3 +1299,49 @@ def inj2():
assert filename.endswith('test_orca.py')
assert isinstance(lineno, int)
assert 'def inj2()' in source


def test_table_view_expressions():

# define a data frame and add some columns to it
orca.add_table(
'my_df',
pd.DataFrame(
{
'a': [1, 1, 1],
'b': [2, 2, 2]
}
)
)

@orca.column('my_df')
def c():
return pd.Series([3, 3, 3])

@orca.column('my_df')
def d():
return pd.Series([4, 4, 4])

# case 1 -- evaluate all columns
@orca.table()
def test1(df='my_df.*'):
return df * -1

assert (orca.get_table_view('test1').values.flatten() ==
np.tile([-1, -2, -3, -4], 3)).all()

# case 2 -- just local
@orca.table()
def test2(df='my_df.local'):
return df * -1

assert (orca.get_table_view('test2').values.flatten() ==
np.tile([-1, -2], 3)).all()

# case 3 -- specific columns
@orca.table()
def test3(df='my_df[a, d]'):
return df * -1

assert (orca.get_table_view('test3').values.flatten() ==
np.tile([-1, -4], 3)).all()