Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: distinguish server timeouts from transport timeouts #43

Merged
merged 6 commits into from
Mar 9, 2020
37 changes: 13 additions & 24 deletions google/cloud/bigquery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
except ImportError: # Python 2.7
import collections as collections_abc

import concurrent.futures
import copy
import functools
import gzip
Expand All @@ -48,7 +47,6 @@
import google.api_core.client_options
import google.api_core.exceptions
from google.api_core import page_iterator
from google.auth.transport.requests import TimeoutGuard
import google.cloud._helpers
from google.cloud import exceptions
from google.cloud.client import ClientWithProject
Expand Down Expand Up @@ -2598,27 +2596,22 @@ def list_partitions(self, table, retry=DEFAULT_RETRY, timeout=None):
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
If multiple requests are made under the hood, ``timeout``
applies to each individual request.

Returns:
List[str]:
A list of the partition IDs present in the partitioned table.
"""
table = _table_arg_to_table_ref(table, default_project=self.project)

with TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
) as guard:
meta_table = self.get_table(
TableReference(
DatasetReference(table.project, table.dataset_id),
"%s$__PARTITIONS_SUMMARY__" % table.table_id,
),
retry=retry,
timeout=timeout,
)
timeout = guard.remaining_timeout
meta_table = self.get_table(
TableReference(
DatasetReference(table.project, table.dataset_id),
"%s$__PARTITIONS_SUMMARY__" % table.table_id,
),
retry=retry,
timeout=timeout,
)

subset = [col for col in meta_table.schema if col.name == "partition_id"]
return [
Expand Down Expand Up @@ -2685,8 +2678,8 @@ def list_rows(
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
If multiple requests are made under the hood, ``timeout``
applies to each individual request.

Returns:
google.cloud.bigquery.table.RowIterator:
Expand All @@ -2711,11 +2704,7 @@ def list_rows(
# No schema, but no selected_fields. Assume the developer wants all
# columns, so get the table resource for them rather than failing.
elif len(schema) == 0:
with TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
) as guard:
table = self.get_table(table.reference, retry=retry, timeout=timeout)
timeout = guard.remaining_timeout
table = self.get_table(table.reference, retry=retry, timeout=timeout)
schema = table.schema

params = {}
Expand Down
62 changes: 19 additions & 43 deletions google/cloud/bigquery/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from six.moves import http_client

import google.api_core.future.polling
from google.auth.transport.requests import TimeoutGuard
from google.cloud import exceptions
from google.cloud.exceptions import NotFound
from google.cloud.bigquery.dataset import Dataset
Expand Down Expand Up @@ -55,7 +54,6 @@
_DONE_STATE = "DONE"
_STOPPED_REASON = "stopped"
_TIMEOUT_BUFFER_SECS = 0.1
_SERVER_TIMEOUT_MARGIN_SECS = 1.0
_CONTAINS_ORDER_BY = re.compile(r"ORDER\s+BY", re.IGNORECASE)

_ERROR_REASON_TO_EXCEPTION = {
Expand Down Expand Up @@ -796,8 +794,8 @@ def result(self, retry=DEFAULT_RETRY, timeout=None):
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
If multiple requests are made under the hood, ``timeout``
applies to each individual request.

Returns:
_AsyncJob: This instance.
Expand All @@ -809,11 +807,7 @@ def result(self, retry=DEFAULT_RETRY, timeout=None):
if the job did not complete in the given timeout.
"""
if self.state is None:
with TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
) as guard:
self._begin(retry=retry, timeout=timeout)
timeout = guard.remaining_timeout
self._begin(retry=retry, timeout=timeout)
# TODO: modify PollingFuture so it can pass a retry argument to done().
return super(_AsyncJob, self).result(timeout=timeout)

Expand Down Expand Up @@ -2602,6 +2596,7 @@ def __init__(self, job_id, query, client, job_config=None):
self._configuration = job_config
self._query_results = None
self._done_timeout = None
self._transport_timeout = None

@property
def allow_large_results(self):
Expand Down Expand Up @@ -3059,19 +3054,9 @@ def done(self, retry=DEFAULT_RETRY, timeout=None):
self._done_timeout = max(0, self._done_timeout)
timeout_ms = int(api_timeout * 1000)

# If the server-side processing timeout (timeout_ms) is specified and
# would be picked as the total request timeout, we want to add a small
# margin to it - we don't want to timeout the connection just as the
# server-side processing might have completed, but instead slightly
# after the server-side deadline.
# However, if `timeout` is specified, and is shorter than the adjusted
# server timeout, the former prevails.
if timeout_ms is not None and timeout_ms > 0:
server_timeout_with_margin = timeout_ms / 1000 + _SERVER_TIMEOUT_MARGIN_SECS
if timeout is not None:
timeout = min(server_timeout_with_margin, timeout)
else:
timeout = server_timeout_with_margin
# If an explicit timeout is not given, fall back to the transport timeout
# that _blocking_poll() stored while polling for job completion.
transport_timeout = timeout if timeout is not None else self._transport_timeout

# Do not refresh if the state is already done, as the job will not
# change once complete.
Expand All @@ -3082,19 +3067,20 @@ def done(self, retry=DEFAULT_RETRY, timeout=None):
project=self.project,
timeout_ms=timeout_ms,
location=self.location,
timeout=timeout,
timeout=transport_timeout,
)

# Only reload the job once we know the query is complete.
# This will ensure that fields such as the destination table are
# correctly populated.
if self._query_results.complete:
self.reload(retry=retry, timeout=timeout)
self.reload(retry=retry, timeout=transport_timeout)

return self.state == _DONE_STATE

def _blocking_poll(self, timeout=None):
self._done_timeout = timeout
self._transport_timeout = timeout
super(QueryJob, self)._blocking_poll(timeout=timeout)

@staticmethod
Expand Down Expand Up @@ -3170,8 +3156,8 @@ def result(
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
If multiple requests are made under the hood, ``timeout``
applies to each individual request.

Returns:
google.cloud.bigquery.table.RowIterator:
Expand All @@ -3189,27 +3175,17 @@ def result(
If the job did not complete in the given timeout.
"""
try:
guard = TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
)
with guard:
super(QueryJob, self).result(retry=retry, timeout=timeout)
timeout = guard.remaining_timeout
super(QueryJob, self).result(retry=retry, timeout=timeout)

# Return an iterator instead of returning the job.
if not self._query_results:
guard = TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
self._query_results = self._client._get_query_results(
self.job_id,
retry,
project=self.project,
location=self.location,
timeout=timeout,
)
with guard:
self._query_results = self._client._get_query_results(
self.job_id,
retry,
project=self.project,
location=self.location,
timeout=timeout,
)
timeout = guard.remaining_timeout
except exceptions.GoogleCloudError as exc:
exc.message += self._format_for_exception(self.query, self.job_id)
exc.query_job = self
Expand Down
78 changes: 0 additions & 78 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import unittest
import warnings

import freezegun
plamut marked this conversation as resolved.
Show resolved Hide resolved
import mock
import requests
import six
Expand Down Expand Up @@ -5496,43 +5495,6 @@ def test_list_partitions_with_string_id(self):

self.assertEqual(len(partition_list), 0)

def test_list_partitions_splitting_timout_between_requests(self):
from google.cloud.bigquery.table import Table

row_count = 2
meta_info = _make_list_partitons_meta_info(
self.PROJECT, self.DS_ID, self.TABLE_ID, row_count
)

data = {
"totalRows": str(row_count),
"rows": [{"f": [{"v": "20180101"}]}, {"f": [{"v": "20180102"}]}],
}
creds = _make_credentials()
http = object()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
client._connection = make_connection(meta_info, data)
table = Table(self.TABLE_REF)

with freezegun.freeze_time("2019-01-01 00:00:00", tick=False) as frozen_time:

def delayed_get_table(*args, **kwargs):
frozen_time.tick(delta=1.4)
return orig_get_table(*args, **kwargs)

orig_get_table = client.get_table
client.get_table = mock.Mock(side_effect=delayed_get_table)

client.list_partitions(table, timeout=5.0)

client.get_table.assert_called_once()
_, kwargs = client.get_table.call_args
self.assertEqual(kwargs.get("timeout"), 5.0)

client._connection.api_request.assert_called()
_, kwargs = client._connection.api_request.call_args
self.assertAlmostEqual(kwargs.get("timeout"), 3.6, places=5)

def test_list_rows(self):
import datetime
from google.cloud._helpers import UTC
Expand Down Expand Up @@ -5918,46 +5880,6 @@ def test_list_rows_with_missing_schema(self):
self.assertEqual(rows[1].age, 31, msg=repr(table))
self.assertIsNone(rows[2].age, msg=repr(table))

def test_list_rows_splitting_timout_between_requests(self):
from google.cloud.bigquery.schema import SchemaField
from google.cloud.bigquery.table import Table

response = {"totalRows": "0", "rows": []}
creds = _make_credentials()
http = object()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
client._connection = make_connection(response, response)

table = Table(
self.TABLE_REF, schema=[SchemaField("field_x", "INTEGER", mode="NULLABLE")]
)

with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen_time:

def delayed_get_table(*args, **kwargs):
frozen_time.tick(delta=1.4)
return table

client.get_table = mock.Mock(side_effect=delayed_get_table)

rows_iter = client.list_rows(
"{}.{}.{}".format(
self.TABLE_REF.project,
self.TABLE_REF.dataset_id,
self.TABLE_REF.table_id,
),
timeout=5.0,
)
six.next(rows_iter.pages)

client.get_table.assert_called_once()
_, kwargs = client.get_table.call_args
self.assertEqual(kwargs.get("timeout"), 5.0)

client._connection.api_request.assert_called_once()
_, kwargs = client._connection.api_request.call_args
self.assertAlmostEqual(kwargs.get("timeout"), 3.6)

def test_list_rows_error(self):
creds = _make_credentials()
http = object()
Expand Down
Loading