Skip to content

Commit 2ad37e1

Browse files
committed
update to newer version of seaborn
1 parent 558c40f commit 2ad37e1

6 files changed

Lines changed: 31 additions & 31 deletions

File tree

statsplot/dimred.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,14 @@ def __init__(
133133
self, data, method=PCA, transformation=None, n_components=None, **kargs
134134
):
135135

136-
if n_components is None:
137-
n_components = data.shape[0]
138-
139136
if data.shape[0] > data.shape[1]:
140-
print(
137+
warnings.warn(
141138
"you don't need to reduce dimensionality or your dataset is transposed."
142139
)
143140

141+
if n_components is None:
142+
n_components = min(data.shape)
143+
144144
self.decomposition = method(n_components=n_components, **kargs)
145145

146146
self.rawdata = data
@@ -154,7 +154,7 @@ def __init__(
154154

155155
else:
156156

157-
self.data_ = data.applymap(transformation)
157+
self.data_ = data.map(transformation)
158158

159159
Xt = self.decomposition.fit_transform(self.data_)
160160

statsplot/plot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,10 @@ def vulcanoplot(
189189
ax.annotate(g1, (ax_lim * 0.9, 0), ha="right")
190190
ax.annotate(g2, (-ax_lim * 0.9, 0), ha="left")
191191

192+
192193
# TODO: handle unaligned input.
193194

195+
194196
def statsplot(
195197
variable,
196198
test_variable,
@@ -261,7 +263,7 @@ def statsplot(
261263
if show_dots:
262264
legend = ax.get_legend_handles_labels()
263265

264-
sns.swarmplot(**params, color="k", dodge=True, **swarm_params)
266+
sns.swarmplot(**params, palette="dark:k", dodge=True, **swarm_params)
265267

266268
# add old variable, as dots have unified colors
267269
if grouping_variable is not None:

statsplot/siglabels.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __plot_sig_labels_hue(
7575
# start with y0
7676
y = y0
7777

78-
for idx, text in P_values.iteritems():
78+
for idx, text in P_values.items():
7979

8080
def calculate_hue_offset(group, order):
8181
return (order.index(group) - len(order) * 0.5 + 0.5) / len(order) * width
@@ -112,7 +112,7 @@ def ___plot_sig_labels_xaxis(
112112
P_values = P_values.apply(format_p_value, use_stars=use_stars)
113113

114114
y = y0
115-
for idx, text in P_values.iteritems():
115+
for idx, text in P_values.items():
116116

117117
def calculate_x_offset(group, order):
118118
return order.index(group)
@@ -135,7 +135,6 @@ def plot_all_sig_labels(
135135
ax=None,
136136
**kws,
137137
):
138-
139138
""""""
140139

141140
# define y0 and deltay

statsplot/stats.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __stats_test_all_on_once(values1, values2, test, **test_kws):
7777
res = test(values1, values2, **test_kws)
7878
ResultsDB["Statistic"] = res.statistic
7979
ResultsDB["Pvalue"] = res.pvalue
80+
8081
return ResultsDB
8182

8283

@@ -109,9 +110,9 @@ def two_group_test(
109110
test_kws=None,
110111
correct_for_multiple_testing=True,
111112
):
112-
113113
"""test: a parwise statistical test found in scipy e.g ['mannwhitneyu','ttest_ind']
114-
or a function wich takes two argumens. Additional keyword arguments can be specified by test_kws"""
114+
or a function wich takes two argumens. Additional keyword arguments can be specified by test_kws
115+
"""
115116

116117
# Define test
117118
if test_kws is None:
@@ -189,6 +190,9 @@ def two_group_test(
189190
Pairwise_comp = Test(values1, values2)
190191

191192
Pairwise_comp["median_diff"] = values2.median() - values1.median()
193+
Pairwise_comp["mean_diff"] = values2.mean() - values1.mean()
194+
Pairwise_comp["Median1"] = values1.median()
195+
Pairwise_comp["Median2"] = values2.median()
192196

193197
if min_value >= 0:
194198
Pairwise_comp["log2FC"] = np.log2(values2.mean() + log_delta) - np.log2(
@@ -230,7 +234,6 @@ def calculate_stats(
230234
test="ttest_ind",
231235
**test_kws,
232236
):
233-
234237
"""Calculate pairewise statistical tests optioonally grouped by a grouping variable"""
235238

236239
kws = dict(

statsplot/statstable.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,14 @@ def subset(self, index=None, columns=None):
128128

129129
def groupby(self, groupby, axis=0):
130130
if axis == 0:
131-
G = self.obs.groupby(groupby, axis=0)
131+
G = self.obs.groupby(groupby)
132132

133133
for group in G.indices:
134134
yield (group, self.subset(index=self.obs_names[G.indices[group]]))
135135

136136
elif axis == 1:
137137
# Group by on axis 0. var indexes contain data.columns
138-
G = self.var.groupby(groupby, axis=0)
138+
G = self.var.groupby(groupby)
139139
for group in G.indices:
140140
yield group, self.subset(columns=self.var_names[G.indices[group]])
141141
else:
@@ -397,8 +397,7 @@ def __get_comparisons(self, subset=None):
397397

398398
return list(subset)
399399

400-
401-
# TODO: Hide output axes labesls
400+
# TODO: Hide output axes labesls
402401
def vulcanoplot(
403402
self,
404403
comparisons=None,

statsplot/transformations.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
2-
3-
41
from numpy import log
52
import pandas as pd
63
import numpy as np
74

85

9-
# copied from scikit-bio
6+
# copied from scikit-bio
107
# because I cannot install it
118
def closure(mat):
129
"""
@@ -93,24 +90,25 @@ def multiplicative_replacement(mat, delta=None):
9390
[ 0.0625, 0.4375, 0.4375, 0.0625]])
9491
"""
9592
mat = closure(mat)
96-
z_mat = (mat == 0)
93+
z_mat = mat == 0
9794

9895
num_feats = mat.shape[-1]
9996
tot = z_mat.sum(axis=-1, keepdims=True)
10097

10198
if delta is None:
102-
delta = (1. / num_feats)**2
99+
delta = (1.0 / num_feats) ** 2
103100

104101
zcnts = 1 - tot * delta
105102
if np.any(zcnts) < 0:
106-
raise ValueError('The multiplicative replacement created negative '
107-
'proportions. Consider using a smaller `delta`.')
103+
raise ValueError(
104+
"The multiplicative replacement created negative "
105+
"proportions. Consider using a smaller `delta`."
106+
)
108107
mat = np.where(z_mat, delta, zcnts * mat)
109108
return mat.squeeze()
110109

111110

112-
113-
def clr(data: pd.DataFrame, log=log,features="all"):
111+
def clr(data: pd.DataFrame, log=log, features="all"):
114112
"""
115113
Centered log ratio (CLR) with multiplicative replacement implemented in scikit-bio
116114
"""
@@ -128,6 +126,8 @@ def clr(data: pd.DataFrame, log=log,features="all"):
128126
# Fill in zeros with multiplicative replacement
129127
matrix = multiplicative_replacement(matrix)
130128

129+
matrix = pd.DataFrame(matrix, index=d.index, columns=d.columns)
130+
131131
# CLR
132132
matrix = log(matrix)
133133

@@ -137,7 +137,7 @@ def clr(data: pd.DataFrame, log=log,features="all"):
137137
mean = matrix.mean(1)
138138

139139
elif features.lower() == "nz":
140-
140+
141141
mean = matrix[matrix != 0].mean(1)
142142
elif features.lower() == "iql":
143143
# use mean of features in interquartile range
@@ -147,9 +147,6 @@ def clr(data: pd.DataFrame, log=log,features="all"):
147147
else:
148148
raise Exception("features must be 'all', 'nz', or 'iql'")
149149

150-
151150
matrix = (matrix.T - mean).T
152151

153-
if type(data) == pd.DataFrame:
154-
155-
return pd.DataFrame(matrix, index=d.index, columns=d.columns)
152+
return matrix

0 commit comments

Comments
 (0)