Skip to content

Commit 94b7164

Browse files
committed
added better error handling to db functions
1 parent 3838fdc commit 94b7164

1 file changed

Lines changed: 63 additions & 53 deletions

File tree

src/synack/plugins/db.py

Lines changed: 63 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
1111

12+
from pprint import pprint
1213
from pathlib import Path
1314
from sqlalchemy.orm import sessionmaker
1415
from synack.db.models import Target
@@ -61,17 +62,21 @@ def add_ips(self, results, session=None):
6162
session = self.Session()
6263
close = True
6364

64-
to_insert = [
65-
{'ip': result['ip'], 'target': result['target']}
66-
for result in results
67-
if result.get('ip') and result.get('target')
68-
]
65+
ips_data = list()
6966

70-
stmt = sqlite_insert(IP).values(to_insert)
71-
stmt = stmt.on_conflict_do_nothing(
72-
index_elements=['ip', 'target'],
73-
)
74-
session.execute(stmt)
67+
for result in results:
68+
if result.get('ip') and result.get('target'):
69+
ips_data.append({
70+
'ip': result['ip'],
71+
'target': result['target']
72+
})
73+
74+
if ips_data:
75+
stmt = sqlite_insert(IP).values(ips_data)
76+
stmt = stmt.on_conflict_do_nothing(
77+
index_elements=['ip', 'target'],
78+
)
79+
session.execute(stmt)
7580

7681
if close:
7782
session.commit()
@@ -84,17 +89,19 @@ def add_organizations(self, targets, session=None):
8489
session = self.Session()
8590
close = True
8691

87-
to_insert = list()
92+
organizations_data = list()
93+
8894
for target in targets:
8995
slug = target.get('organization_id', target.get('organization', {}).get('slug'))
9096
if slug:
91-
to_insert.append({'slug': slug})
97+
organizations_data.append({'slug': slug})
9298

93-
stmt = sqlite_insert(Organization).values(to_insert)
94-
stmt = stmt.on_conflict_do_nothing(
95-
index_elements=['slug'],
96-
)
97-
session.execute(stmt)
99+
if organizations_data:
100+
stmt = sqlite_insert(Organization).values(organizations_data)
101+
stmt = stmt.on_conflict_do_nothing(
102+
index_elements=['slug'],
103+
)
104+
session.execute(stmt)
98105

99106
if close:
100107
session.commit()
@@ -122,33 +129,34 @@ def add_ports(self, results):
122129
'updated': port.get('updated')
123130
})
124131

125-
stmt = sqlite_insert(Port).values(ports_data)
126-
stmt = stmt.on_conflict_do_update(
127-
index_elements=['port', 'protocol', 'ip', 'source'],
128-
set_={
129-
'service': stmt.excluded.service,
130-
'open': stmt.excluded.open,
131-
'updated': stmt.excluded.updated
132-
}
133-
)
134-
135-
session.execute(stmt)
132+
if ports_data:
133+
stmt = sqlite_insert(Port).values(ports_data)
134+
stmt = stmt.on_conflict_do_update(
135+
index_elements=['port', 'protocol', 'ip', 'source'],
136+
set_={
137+
'service': stmt.excluded.service,
138+
'open': stmt.excluded.open,
139+
'updated': stmt.excluded.updated
140+
}
141+
)
142+
session.execute(stmt)
143+
136144
session.commit()
137145
session.close()
138146

139147
def add_targets(self, targets, **kwargs):
140148
session = self.Session()
141149

142150
self.add_organizations(targets, session)
143-
db_orgs = session.query(Organization.slug).all()
151+
db_orgs = [org[0] for org in session.query(Organization.slug).all()]
144152

145153
targets_data = list()
146154

147155
for target in targets:
148156
org_slug = target.get('organization_id', target.get('organization', {}).get('slug'))
149157
if org_slug in db_orgs:
150158
target_data = {
151-
'slug': target.get('slug', target.get('id')),
159+
'slug': target.get('id', target.get('slug')),
152160
'category': target['category']['id'],
153161
'organization': org_slug,
154162
'date_updated': target.get('dateUpdated'),
@@ -160,21 +168,22 @@ def add_targets(self, targets, **kwargs):
160168
target_data.update(kwargs)
161169
targets_data.append(target_data)
162170

163-
stmt = sqlite_insert(Target).values(targets_data)
164-
stmt = stmt.on_conflict_do_update(
165-
index_elements=['slug'],
166-
set_={
167-
'category': stmt.excluded.category,
168-
'organization': stmt.excluded.organization,
169-
'date_updated': stmt.excluded.date_updated,
170-
'is_active': stmt.excluded.is_active,
171-
'is_new': stmt.excluded.is_new,
172-
'is_registered': stmt.excluded.is_registered,
173-
'last_submitted': stmt.excluded.last_submitted,
174-
}
175-
)
176-
177-
session.execute(stmt)
171+
if targets_data:
172+
stmt = sqlite_insert(Target).values(targets_data)
173+
stmt = stmt.on_conflict_do_update(
174+
index_elements=['slug'],
175+
set_={
176+
'category': stmt.excluded.category,
177+
'organization': stmt.excluded.organization,
178+
'date_updated': stmt.excluded.date_updated,
179+
'is_active': stmt.excluded.is_active,
180+
'is_new': stmt.excluded.is_new,
181+
'is_registered': stmt.excluded.is_registered,
182+
'last_submitted': stmt.excluded.last_submitted,
183+
}
184+
)
185+
session.execute(stmt)
186+
178187
session.commit()
179188
session.close()
180189

@@ -195,15 +204,16 @@ def add_urls(self, results):
195204
'screenshot_url': url.get('screenshot_url')
196205
})
197206

198-
stmt = sqlite_insert(Url).values(urls_data)
199-
stmt = stmt.on_conflict_do_update(
200-
index_elements=['ip', 'url'],
201-
set_={
202-
'screenshot_url': stmt.excluded.screenshot_url
203-
}
204-
)
207+
if urls_data:
208+
stmt = sqlite_insert(Url).values(urls_data)
209+
stmt = stmt.on_conflict_do_update(
210+
index_elements=['ip', 'url'],
211+
set_={
212+
'screenshot_url': stmt.excluded.screenshot_url
213+
}
214+
)
215+
session.execute(stmt)
205216

206-
session.execute(stmt)
207217
session.commit()
208218
session.close()
209219

0 commit comments

Comments
 (0)