summaryrefslogtreecommitdiff
path: root/raphodo/rpdsql.py
diff options
context:
space:
mode:
authorAntoine Beaupré <anarcat@debian.org>2017-12-30 12:18:30 -0500
committerAntoine Beaupré <anarcat@debian.org>2017-12-30 12:18:30 -0500
commit88c8bd4cd2ee4707f8a43be4d89c4e040dcced2f (patch)
tree01b10a0d80509730942706333f173c5aa7f239eb /raphodo/rpdsql.py
parentc5fc6c6030d7d9d1b2af3d5165bebed3decd741b (diff)
New upstream version 0.9.6upstream/0.9.6
Diffstat (limited to 'raphodo/rpdsql.py')
-rwxr-xr-xraphodo/rpdsql.py132
1 files changed, 112 insertions, 20 deletions
diff --git a/raphodo/rpdsql.py b/raphodo/rpdsql.py
index f65a02b..112aea0 100755
--- a/raphodo/rpdsql.py
+++ b/raphodo/rpdsql.py
@@ -148,7 +148,8 @@ class ThumbnailRowsSQL:
extensions: Optional[List[str]]=None,
proximity_col1: Optional[List[int]]=None,
proximity_col2: Optional[List[int]]=None,
- exclude_scan_ids: Optional[List[int]]=None) -> Tuple[str, Tuple[Any]]:
+ exclude_scan_ids: Optional[List[int]]=None,
+ uids: Optional[List[bytes]]=None,) -> Tuple[str, List[Any]]:
where_clauses = []
where_values = []
@@ -187,6 +188,17 @@ class ThumbnailRowsSQL:
where_clauses.append('extension IN ({})'.format(','.join('?' * len(extensions))))
where_values.extend(extensions)
+ if uids is not None:
+ if len(uids) == 1:
+ where_clauses.append('uid=?')
+ where_values.append(uids[0])
+ else:
+ # assume max host parameters in a single SQL statement is 999
+ if len(uids) > 900:
+ uids = uids[:900]
+ where_clauses.append('uid IN ({})'.format(','.join('?' * len(uids))))
+ where_values.extend(uids)
+
if exclude_scan_ids is not None:
if len(exclude_scan_ids) == 1:
where_clauses.append(('scan_id!=?'))
@@ -216,7 +228,16 @@ class ThumbnailRowsSQL:
where_clauses.append('({})'.format(' OR '.join(or_clauses)))
where = ' AND '.join(where_clauses)
- return (where, where_values)
+ return where, where_values
+
+ def _build_sort(self, sort_by: Sort, sort_order: Qt.SortOrder) -> str:
+ if sort_by == Sort.modification_time:
+ sort = 'ORDER BY mtime {}'.format(self.sort_order_map[sort_order])
+ else:
+ sort = 'ORDER BY {0} {1}, mtime {1}'.format(
+ self.sort_map[sort_by], self.sort_order_map[sort_order]
+ )
+ return sort
def get_view(self, sort_by: Sort,
sort_order: Qt.SortOrder,
@@ -224,15 +245,11 @@ class ThumbnailRowsSQL:
proximity_col1: Optional[List[int]] = None,
proximity_col2: Optional[List[int]] = None) -> List[Tuple[bytes, bool]]:
- where, where_values = self._build_where(show=show,
- proximity_col1=proximity_col1,
- proximity_col2=proximity_col2)
+ where, where_values = self._build_where(
+ show=show, proximity_col1=proximity_col1, proximity_col2=proximity_col2
+ )
- if sort_by == Sort.modification_time:
- sort = 'ORDER BY mtime {}'.format(self.sort_order_map[sort_order])
- else:
- sort = 'ORDER BY {0} {1}, mtime {1}'.format(self.sort_map[sort_by],
- self.sort_order_map[sort_order])
+ sort = self._build_sort(sort_by, sort_order)
query = 'SELECT uid, marked FROM files'
@@ -251,6 +268,38 @@ class ThumbnailRowsSQL:
logging.debug('%s', query)
return self.conn.execute(query).fetchall()
+ def get_first_uid_from_uid_list(self, sort_by: Sort,
+ sort_order: Qt.SortOrder,
+ show: Show,
+ uids: List[bytes],
+ proximity_col1: Optional[List[int]] = None,
+ proximity_col2: Optional[List[int]] = None) -> Optional[bytes]:
+ """
+ Given a list of uids, and sort and filtering criteria, return the first
+ uid that the user will have displayed -- if any are displayed.
+ """
+
+ where, where_values = self._build_where(
+ show=show, proximity_col1=proximity_col1, proximity_col2=proximity_col2, uids=uids
+ )
+
+ sort = self._build_sort(sort_by, sort_order)
+
+ query = 'SELECT uid FROM files'
+
+ if sort_by == Sort.device:
+ query = '{} NATURAL JOIN devices'.format(query)
+
+ query = '{} WHERE {}'.format(query, where)
+
+ query = '{} {}'.format(query, sort)
+
+ logging.debug('%s (using %s where values)', query, len(where_values))
+ row = self.conn.execute(query, tuple(where_values)).fetchone()
+ if row:
+ return row[0]
+ return None
+
def get_uids(self, scan_id: Optional[int]=None,
show: Optional[Show]=None,
previously_downloaded: Optional[bool]=None,
@@ -344,7 +393,12 @@ class ThumbnailRowsSQL:
logging.debug('%s (%s on %s uids)', query, marked, len(uids))
self.conn.execute(query.format(','.join('?' * len(uids))), [marked] + uids)
- def set_list_marked(self, uids: List[bytes], marked: bool) -> None:
+ def _update_previously_downloaded(self, uids: List[bytes], previously_downloaded: bool) -> None:
+ query = 'UPDATE files SET previously_downloaded=? WHERE uid IN ({})'
+ logging.debug('%s (%s on %s uids)', query, previously_downloaded, len(uids))
+ self.conn.execute(query.format(','.join('?' * len(uids))), [previously_downloaded] + uids)
+
+ def _set_list_values(self, uids: List[bytes], update_value, value) -> None:
if len(uids) == 0:
return
@@ -353,11 +407,20 @@ class ThumbnailRowsSQL:
if len(uids) > 900:
uid_chunks = divide_list_on_length(uids, 900)
for chunk in uid_chunks:
- self._update_marked(chunk, marked)
+ update_value(chunk, value)
else:
- self._update_marked(uids, marked)
+ update_value(uids, value)
self.conn.commit()
+ def set_list_marked(self, uids: List[bytes], marked: bool) -> None:
+ self._set_list_values(uids=uids, update_value=self._update_marked, value=marked)
+
+ def set_list_previously_downloaded(self, uids: List[bytes],
+ previously_downloaded: bool) -> None:
+ self._set_list_values(
+ uids=uids, update_value=self._update_previously_downloaded, value=previously_downloaded
+ )
+
def set_downloaded(self, uid: bytes, downloaded: bool) -> None:
query = 'UPDATE files SET downloaded=? WHERE uid=?'
logging.debug('%s (%s, <uid>)', query, downloaded)
@@ -467,8 +530,10 @@ class ThumbnailRowsSQL:
downloaded: Optional[bool] = None,
scan_id: Optional[int]=None,
exclude_scan_ids: Optional[List[int]] = None) -> Optional[bytes]:
- where, where_values = self._build_where(scan_id=scan_id, downloaded=downloaded,
- file_type=file_type, exclude_scan_ids=exclude_scan_ids)
+ where, where_values = self._build_where(
+ scan_id=scan_id, downloaded=downloaded, file_type=file_type,
+ exclude_scan_ids=exclude_scan_ids
+ )
query = 'SELECT uid FROM files'
if where:
@@ -486,15 +551,37 @@ class ThumbnailRowsSQL:
return row[0]
def any_marked_file_no_job_code(self) -> bool:
- row = self.conn.execute('SELECT uid FROM files WHERE marked=1 AND job_code=0 '
- 'LIMIT 1').fetchone()
+ row = self.conn.execute(
+ 'SELECT uid FROM files WHERE marked=1 AND job_code=0 LIMIT 1'
+ ).fetchone()
+ return row is not None
+
+ def _any_not_previously_downloaded(self, uids: List[bytes]) -> bool:
+ query = 'SELECT uid FROM files WHERE uid IN ({}) AND previously_downloaded=0 LIMIT 1'
+ logging.debug('%s (%s files)', query, len(uids))
+ row = self.conn.execute(query.format(','.join('?' * len(uids))), uids).fetchone()
return row is not None
+ def any_not_previously_downloaded(self, uids: List[bytes]) -> bool:
+ """
+
+ :param uids: list of UIDs to check
+ :return: True if any of the files associated with the UIDs have not been
+ previously downloaded
+ """
+ if len(uids) > 900:
+ uid_chunks = divide_list_on_length(uids, 900)
+ for chunk in uid_chunks:
+ if self._any_not_previously_downloaded(uids=uid_chunks):
+ return True
+ return False
+ else:
+ return self._any_not_previously_downloaded(uids=uids)
+
def _delete_uids(self, uids: List[bytes]) -> None:
query = 'DELETE FROM files WHERE uid IN ({})'
logging.debug('%s (%s files)', query, len(uids))
- self.conn.execute(query.format(
- ','.join('?' * len(uids))), uids)
+ self.conn.execute(query.format(','.join('?' * len(uids))), uids)
def delete_uids(self, uids: List[bytes]) -> None:
"""
@@ -576,6 +663,9 @@ class DownloadedSQL:
PRIMARY KEY (file_name, mtime, size)
)""".format(tn=self.table_name))
+ # Use the character . to for download_name and path to indicate the user manually marked a
+ # file as previously downloaded
+
conn.execute("""CREATE INDEX IF NOT EXISTS download_datetime_idx ON
{tn} (download_name)""".format(tn=self.table_name))
@@ -589,7 +679,9 @@ class DownloadedSQL:
:param name: original filename of photo / video, without path
:param size: file size
:param modification_time: file modification time
- :param download_full_file_name: renamed file including path
+ :param download_full_file_name: renamed file including path,
+ or the character . that the user manually marked the file
+ as previously downloaded
"""
conn = sqlite3.connect(self.db)