diff options
Diffstat (limited to 'raphodo/rpdsql.py')
-rwxr-xr-x | raphodo/rpdsql.py | 132 |
1 files changed, 112 insertions, 20 deletions
diff --git a/raphodo/rpdsql.py b/raphodo/rpdsql.py index f65a02b..112aea0 100755 --- a/raphodo/rpdsql.py +++ b/raphodo/rpdsql.py @@ -148,7 +148,8 @@ class ThumbnailRowsSQL: extensions: Optional[List[str]]=None, proximity_col1: Optional[List[int]]=None, proximity_col2: Optional[List[int]]=None, - exclude_scan_ids: Optional[List[int]]=None) -> Tuple[str, Tuple[Any]]: + exclude_scan_ids: Optional[List[int]]=None, + uids: Optional[List[bytes]]=None,) -> Tuple[str, List[Any]]: where_clauses = [] where_values = [] @@ -187,6 +188,17 @@ class ThumbnailRowsSQL: where_clauses.append('extension IN ({})'.format(','.join('?' * len(extensions)))) where_values.extend(extensions) + if uids is not None: + if len(uids) == 1: + where_clauses.append('uid=?') + where_values.append(uids[0]) + else: + # assume max host parameters in a single SQL statement is 999 + if len(uids) > 900: + uids = uids[:900] + where_clauses.append('uid IN ({})'.format(','.join('?' * len(uids)))) + where_values.extend(uids) + if exclude_scan_ids is not None: if len(exclude_scan_ids) == 1: where_clauses.append(('scan_id!=?')) @@ -216,7 +228,16 @@ class ThumbnailRowsSQL: where_clauses.append('({})'.format(' OR '.join(or_clauses))) where = ' AND '.join(where_clauses) - return (where, where_values) + return where, where_values + + def _build_sort(self, sort_by: Sort, sort_order: Qt.SortOrder) -> str: + if sort_by == Sort.modification_time: + sort = 'ORDER BY mtime {}'.format(self.sort_order_map[sort_order]) + else: + sort = 'ORDER BY {0} {1}, mtime {1}'.format( + self.sort_map[sort_by], self.sort_order_map[sort_order] + ) + return sort def get_view(self, sort_by: Sort, sort_order: Qt.SortOrder, @@ -224,15 +245,11 @@ class ThumbnailRowsSQL: proximity_col1: Optional[List[int]] = None, proximity_col2: Optional[List[int]] = None) -> List[Tuple[bytes, bool]]: - where, where_values = self._build_where(show=show, - proximity_col1=proximity_col1, - proximity_col2=proximity_col2) + where, where_values = self._build_where( + show=show, proximity_col1=proximity_col1, proximity_col2=proximity_col2 + ) - if sort_by == Sort.modification_time: - sort = 'ORDER BY mtime {}'.format(self.sort_order_map[sort_order]) - else: - sort = 'ORDER BY {0} {1}, mtime {1}'.format(self.sort_map[sort_by], - self.sort_order_map[sort_order]) + sort = self._build_sort(sort_by, sort_order) query = 'SELECT uid, marked FROM files' @@ -251,6 +268,38 @@ class ThumbnailRowsSQL: logging.debug('%s', query) return self.conn.execute(query).fetchall() + def get_first_uid_from_uid_list(self, sort_by: Sort, + sort_order: Qt.SortOrder, + show: Show, + uids: List[bytes], + proximity_col1: Optional[List[int]] = None, + proximity_col2: Optional[List[int]] = None) -> Optional[bytes]: + """ + Given a list of uids, and sort and filtering criteria, return the first + uid that the user will have displayed -- if any are displayed. + """ + + where, where_values = self._build_where( + show=show, proximity_col1=proximity_col1, proximity_col2=proximity_col2, uids=uids + ) + + sort = self._build_sort(sort_by, sort_order) + + query = 'SELECT uid FROM files' + + if sort_by == Sort.device: + query = '{} NATURAL JOIN devices'.format(query) + + query = '{} WHERE {}'.format(query, where) + + query = '{} {}'.format(query, sort) + + logging.debug('%s (using %s where values)', query, len(where_values)) + row = self.conn.execute(query, tuple(where_values)).fetchone() + if row: + return row[0] + return None + def get_uids(self, scan_id: Optional[int]=None, show: Optional[Show]=None, previously_downloaded: Optional[bool]=None, @@ -344,7 +393,12 @@ class ThumbnailRowsSQL: logging.debug('%s (%s on %s uids)', query, marked, len(uids)) self.conn.execute(query.format(','.join('?' * len(uids))), [marked] + uids) - def set_list_marked(self, uids: List[bytes], marked: bool) -> None: + def _update_previously_downloaded(self, uids: List[bytes], previously_downloaded: bool) -> None: + query = 'UPDATE files SET previously_downloaded=? WHERE uid IN ({})' + logging.debug('%s (%s on %s uids)', query, previously_downloaded, len(uids)) + self.conn.execute(query.format(','.join('?' * len(uids))), [previously_downloaded] + uids) + + def _set_list_values(self, uids: List[bytes], update_value, value) -> None: if len(uids) == 0: return @@ -353,11 +407,20 @@ class ThumbnailRowsSQL: if len(uids) > 900: uid_chunks = divide_list_on_length(uids, 900) for chunk in uid_chunks: - self._update_marked(chunk, marked) + update_value(chunk, value) else: - self._update_marked(uids, marked) + update_value(uids, value) self.conn.commit() + def set_list_marked(self, uids: List[bytes], marked: bool) -> None: + self._set_list_values(uids=uids, update_value=self._update_marked, value=marked) + + def set_list_previously_downloaded(self, uids: List[bytes], + previously_downloaded: bool) -> None: + self._set_list_values( + uids=uids, update_value=self._update_previously_downloaded, value=previously_downloaded + ) + def set_downloaded(self, uid: bytes, downloaded: bool) -> None: query = 'UPDATE files SET downloaded=? WHERE uid=?' logging.debug('%s (%s, <uid>)', query, downloaded) @@ -467,8 +530,10 @@ class ThumbnailRowsSQL: downloaded: Optional[bool] = None, scan_id: Optional[int]=None, exclude_scan_ids: Optional[List[int]] = None) -> Optional[bytes]: - where, where_values = self._build_where(scan_id=scan_id, downloaded=downloaded, - file_type=file_type, exclude_scan_ids=exclude_scan_ids) + where, where_values = self._build_where( + scan_id=scan_id, downloaded=downloaded, file_type=file_type, + exclude_scan_ids=exclude_scan_ids + ) query = 'SELECT uid FROM files' if where: @@ -486,15 +551,37 @@ class ThumbnailRowsSQL: return row[0] def any_marked_file_no_job_code(self) -> bool: - row = self.conn.execute('SELECT uid FROM files WHERE marked=1 AND job_code=0 ' - 'LIMIT 1').fetchone() + row = self.conn.execute( + 'SELECT uid FROM files WHERE marked=1 AND job_code=0 LIMIT 1' + ).fetchone() + return row is not None + + def _any_not_previously_downloaded(self, uids: List[bytes]) -> bool: + query = 'SELECT uid FROM files WHERE uid IN ({}) AND previously_downloaded=0 LIMIT 1' + logging.debug('%s (%s files)', query, len(uids)) + row = self.conn.execute(query.format(','.join('?' * len(uids))), uids).fetchone() return row is not None + def any_not_previously_downloaded(self, uids: List[bytes]) -> bool: + """ + + :param uids: list of UIDs to check + :return: True if any of the files associated with the UIDs have not been + previously downloaded + """ + if len(uids) > 900: + uid_chunks = divide_list_on_length(uids, 900) + for chunk in uid_chunks: + if self._any_not_previously_downloaded(uids=uid_chunks): + return True + return False + else: + return self._any_not_previously_downloaded(uids=uids) + def _delete_uids(self, uids: List[bytes]) -> None: query = 'DELETE FROM files WHERE uid IN ({})' logging.debug('%s (%s files)', query, len(uids)) - self.conn.execute(query.format( - ','.join('?' * len(uids))), uids) + self.conn.execute(query.format(','.join('?' * len(uids))), uids) def delete_uids(self, uids: List[bytes]) -> None: """ @@ -576,6 +663,9 @@ class DownloadedSQL: PRIMARY KEY (file_name, mtime, size) )""".format(tn=self.table_name)) + # Use the character . to for download_name and path to indicate the user manually marked a + # file as previously downloaded + conn.execute("""CREATE INDEX IF NOT EXISTS download_datetime_idx ON {tn} (download_name)""".format(tn=self.table_name)) @@ -589,7 +679,9 @@ class DownloadedSQL: :param name: original filename of photo / video, without path :param size: file size :param modification_time: file modification time - :param download_full_file_name: renamed file including path + :param download_full_file_name: renamed file including path, + or the character . that the user manually marked the file + as previously downloaded """ conn = sqlite3.connect(self.db) |