# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for the gcloud dataproc tool."""


import base64
import hashlib
import json
import os
import subprocess
import tempfile
import time
import uuid
from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.dataproc import exceptions
from googlecloudsdk.api_lib.dataproc import storage_helpers
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import requests
from googlecloudsdk.core.console import console_attr
from googlecloudsdk.core.console import console_io
from googlecloudsdk.core.console import progress_tracker
from googlecloudsdk.core.credentials import store as c_store
from googlecloudsdk.core.util import retry

import six

# Path of the 'schemas' directory located alongside this module.
SCHEMA_DIR = os.path.join(os.path.dirname(__file__), 'schemas')


def FormatRpcError(error):
  """Returns the human-readable message of a failed Google API Status.

  The complete status is additionally written to the debug log as JSON.

  Args:
    error: the failed Status message to describe.

  Returns:
    The `message` field of the given status.
  """
  status_as_json = encoding.MessageToJson(error)
  log.debug('Error:\n' + status_as_json)
  return error.message


def WaitForResourceDeletion(request_method,
                            resource_ref,
                            message,
                            timeout_s=60,
                            poll_period_s=5):
  """Poll Dataproc resource until it no longer exists.

  Args:
    request_method: callable invoked as `request_method(resource_ref)` on every
      poll; deletion is detected when it raises HttpNotFoundError.
    resource_ref: the resource reference passed to `request_method`.
    message: str, message shown by the progress tracker while polling.
    timeout_s: number, seconds to keep polling before giving up.
    poll_period_s: number, delay in seconds between polls.

  Raises:
    apitools_exceptions.HttpError: if a poll fails with a 4xx error other than
      "not found".
    OperationTimeoutError: if the resource still exists after `timeout_s`.
  """
  with progress_tracker.ProgressTracker(message, autotick=True):
    start_time = time.time()
    while timeout_s > (time.time() - start_time):
      try:
        request_method(resource_ref)
      except apitools_exceptions.HttpNotFoundError:
        # Object deleted
        return
      except apitools_exceptions.HttpError as error:
        # The template uses str.format placeholders, so it has to be rendered
        # here: logging would apply %-formatting to extra positional args and
        # fail to build the message.
        log.debug('Get request for [{0}] failed:\n{1}'.format(
            resource_ref, error))

        # Do not retry on 4xx errors
        if IsClientHttpException(error):
          raise
      time.sleep(poll_period_s)
  raise exceptions.OperationTimeoutError(
      'Deleting resource [{0}] timed out.'.format(resource_ref))


def GetUniqueId():
  """Returns a newly generated random UUID4 as a 32-character hex string."""
  fresh_uuid = uuid.uuid4()
  return fresh_uuid.hex


class Bunch(object):
  """Class that converts a dictionary to javascript like object.

  Every key of the dictionary becomes an attribute of the instance. Values
  that are themselves dictionaries are converted recursively; all other values
  (including lists that contain dictionaries) are stored unchanged.

  For example:
      Bunch({'a': {'b': {'c': 0}}}).a.b.c == 0
  """

  def __init__(self, dictionary):
    """Populates the instance attributes from `dictionary`.

    Args:
      dictionary: dict, the mapping whose keys become attribute names.
    """
    # This module already relies on Python 3 only syntax (f-strings), so the
    # built-in items() view is used instead of a py2/py3 compatibility shim.
    for key, value in dictionary.items():
      if isinstance(value, dict):
        value = Bunch(value)
      # Written through __dict__ so that any hashable key is accepted.
      self.__dict__[key] = value


def AddJvmDriverFlags(parser):
  """Registers the --jar and --class driver flags on the given parser.

  Args:
    parser: argparse-style parser exposing `add_argument`. The parsed values
      are stored under `main_jar` and `main_class` respectively.
  """
  driver_flags = (
      ('--jar', 'main_jar',
       'The HCFS URI of jar file containing the driver jar.'),
      ('--class', 'main_class',
       'The class containing the main method of the driver. Must be in a'
       ' provided jar or jar that is already on the classpath'),
  )
  for flag_name, destination, help_text in driver_flags:
    parser.add_argument(flag_name, dest=destination, help=help_text)


def IsClientHttpException(http_exception):
  """Returns true if the given http exception carries an HTTP 4xx status."""
  return 400 <= http_exception.status_code < 500


# TODO(b/36056506): Use api_lib.utils.waiter
def WaitForOperation(dataproc, operation, message, timeout_s, poll_period_s=5):
  """Poll dataproc Operation until its status is done or timeout reached.

  Warnings found in the operation's ClusterOperationMetadata are logged as
  they appear; each warning is logged once.

  Args:
    dataproc: wrapper for Dataproc messages, resources, and client
    operation: Operation, message of the operation to be polled.
    message: str, message to display to user while polling.
    timeout_s: number, seconds to poll with retries before timing out.
    poll_period_s: number, delay in seconds between requests.

  Returns:
    Operation: the return value of the last successful operations.get
    request.

  Raises:
    OperationTimeoutError: if the operation is not done after `timeout_s`.
    OperationError: if the operation finishes with an error.
    apitools_exceptions.HttpError: if a poll fails with a 4xx error.
  """
  request = dataproc.messages.DataprocProjectsRegionsOperationsGetRequest(
      name=operation.name)
  log.status.Print('Waiting on operation [{0}].'.format(operation.name))
  start_time = time.time()
  warnings_so_far = 0
  is_tty = console_io.IsInteractive(error=True)
  tracker_separator = '\n' if is_tty else ''

  def _LogWarnings(warnings):
    # warnings_so_far is read from the enclosing scope at call time, so only
    # warnings that were not present at the previous poll are emitted.
    new_warnings = warnings[warnings_so_far:]
    if new_warnings:
      # Drop a line to print nicely with the progress tracker.
      log.err.write(tracker_separator)
      for warning in new_warnings:
        log.warning(warning)

  with progress_tracker.ProgressTracker(message, autotick=True):
    while timeout_s > (time.time() - start_time):
      try:
        operation = dataproc.client.projects_regions_operations.Get(request)
        metadata = ParseOperationJsonMetadata(
            operation.metadata, dataproc.messages.ClusterOperationMetadata)
        _LogWarnings(metadata.warnings)
        warnings_so_far = len(metadata.warnings)
        if operation.done:
          break
      except apitools_exceptions.HttpError as http_exception:
        # Do not retry on 4xx errors.
        if IsClientHttpException(http_exception):
          raise
        # Any other HTTP failure is retried on the next poll; record it so
        # that repeated failures leave a trace in the debug log.
        log.debug('Get request for operation [{0}] failed:\n{1}'.format(
            operation.name, http_exception))
      time.sleep(poll_period_s)
  # Covers the case where the loop ended without a successful final poll.
  metadata = ParseOperationJsonMetadata(
      operation.metadata, dataproc.messages.ClusterOperationMetadata)
  _LogWarnings(metadata.warnings)
  if not operation.done:
    raise exceptions.OperationTimeoutError('Operation [{0}] timed out.'.format(
        operation.name))
  elif operation.error:
    raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format(
        operation.name, FormatRpcError(operation.error)))

  log.info('Operation [%s] finished after %.3f seconds', operation.name,
           (time.time() - start_time))
  return operation


def PrintWorkflowMetadata(metadata, status, operations, errors):
  """Reports workflow, cluster and job changes found in the latest metadata.

  Each call compares the given metadata with what earlier calls recorded in
  `status`, `operations` and `errors`, prints a line for everything that
  changed, and updates those dictionaries in place.

  For example:
    WorkflowTemplate [template-name] RUNNING
    Creating cluster: Operation ID [create-id].
    Job ID job-id-1 RUNNING
    Job ID job-id-1 COMPLETED
    Deleting cluster: Operation ID [delete-id].
    WorkflowTemplate [template-name] DONE

  Args:
    metadata: Dataproc WorkflowMetadata message object, contains the latest
      states of a workflow template.
    status: Dictionary, maps job ids to their last printed state; the state of
      the overarching workflow is kept under the key 'wt'.
    operations: Dictionary, holds the last seen 'createCluster' and
      'deleteCluster' cluster operations. Both keys must already be present.
    errors: Dictionary, maps job ids to their last printed error.
  """

  def _HasValue(cluster_operation, field_name):
    # True when the field exists on the object and is not None.
    return (hasattr(cluster_operation, field_name) and
            getattr(cluster_operation, field_name) is not None)

  def _ReportClusterOperation(operation_key, done_format, in_progress_format):
    # Prints the most specific fact known about a changed cluster operation
    # (error, then completion, then operation id) and remembers it.
    current = getattr(metadata, operation_key)
    if current != operations[operation_key]:
      if _HasValue(current, 'error'):
        log.status.Print(current.error)
      elif _HasValue(current, 'done'):
        log.status.Print(done_format.format(metadata.clusterName))
      elif _HasValue(current, 'operationId'):
        log.status.Print(in_progress_format.format(current.operationId))
      operations[operation_key] = current

  # Key chosen to avoid collision with job ids, which are at least 3 characters.
  workflow_key = 'wt'
  workflow_state = metadata.state
  if workflow_key not in status or workflow_state != status[workflow_key]:
    if metadata.template is None:
      # Workflows instantiated inline do not store an id in their metadata.
      log.status.Print('WorkflowTemplate {0}'.format(workflow_state))
    else:
      log.status.Print('WorkflowTemplate [{0}] {1}'.format(
          metadata.template, workflow_state))
    status[workflow_key] = workflow_state

  _ReportClusterOperation('createCluster', 'Created cluster: {0}.',
                          'Creating cluster: Operation ID [{0}].')

  if hasattr(metadata.graph, 'nodes'):
    for node in metadata.graph.nodes:
      job_id = node.jobId
      # Nodes without a job id have nothing to report yet.
      if not job_id:
        continue
      if job_id not in status or status[job_id] != node.state:
        log.status.Print('Job ID {0} {1}'.format(job_id, node.state))
        status[job_id] = node.state
      if node.error and (job_id not in errors or
                         errors[job_id] != node.error):
        log.status.Print('Job ID {0} error: {1}'.format(job_id, node.error))
        errors[job_id] = node.error

  _ReportClusterOperation('deleteCluster', 'Deleted cluster: {0}.',
                          'Deleting cluster: Operation ID [{0}].')


# TODO(b/36056506): Use api_lib.utils.waiter
def WaitForWorkflowTemplateOperation(dataproc,
                                     operation,
                                     timeout_s=None,
                                     poll_period_s=5):
  """Poll dataproc Operation until its status is done or timeout reached.

  While polling, changes in the workflow's WorkflowMetadata (workflow state,
  cluster operations, job states and job errors) are printed as they occur.

  Args:
    dataproc: wrapper for Dataproc messages, resources, and client
    operation: Operation, message of the operation to be polled.
    timeout_s: number, seconds to poll with retries before timing out. None
      means poll until the operation is done.
    poll_period_s: number, delay in seconds between requests.

  Returns:
    Operation: the return value of the last successful operations.get
    request.

  Raises:
    OperationTimeoutError: if the operation is not done after `timeout_s`.
    OperationError: if the operation, or the workflow's cluster create/delete
      operation, finishes with an error.
    apitools_exceptions.HttpError: if a poll fails with a 4xx error.
  """
  request = dataproc.messages.DataprocProjectsRegionsOperationsGetRequest(
      name=operation.name)
  log.status.Print('Waiting on operation [{0}].'.format(operation.name))
  start_time = time.time()
  # State shared with PrintWorkflowMetadata across polls so that only changes
  # are printed.
  operations = {'createCluster': None, 'deleteCluster': None}
  status = {}
  errors = {}

  # If no timeout is specified, poll forever.
  while timeout_s is None or timeout_s > (time.time() - start_time):
    try:
      operation = dataproc.client.projects_regions_operations.Get(request)
      metadata = ParseOperationJsonMetadata(operation.metadata,
                                            dataproc.messages.WorkflowMetadata)

      PrintWorkflowMetadata(metadata, status, operations, errors)
      if operation.done:
        break
    except apitools_exceptions.HttpError as http_exception:
      # Do not retry on 4xx errors.
      if IsClientHttpException(http_exception):
        raise
      # Any other HTTP failure is retried on the next poll; record it so that
      # repeated failures leave a trace in the debug log.
      log.debug('Get request for operation [{0}] failed:\n{1}'.format(
          operation.name, http_exception))
    time.sleep(poll_period_s)

  if not operation.done:
    raise exceptions.OperationTimeoutError('Operation [{0}] timed out.'.format(
        operation.name))
  elif operation.error:
    raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format(
        operation.name, FormatRpcError(operation.error)))
  # The workflow operation itself can succeed while one of its cluster
  # operations reported an error; surface that as a failure too.
  for op in ['createCluster', 'deleteCluster']:
    if operations[op] is not None and operations[op].error:
      raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format(
          operations[op].operationId, operations[op].error))

  log.info('Operation [%s] finished after %.3f seconds', operation.name,
           (time.time() - start_time))
  return operation


class NoOpProgressDisplay(object):
  """Do-nothing context manager usable where a ProgressTracker is expected.

  Entering yields None and exiting never suppresses an exception.
  """

  def __enter__(self):
    return None

  def __exit__(self, *_exc_info):
    return None


def WaitForJobTermination(dataproc,
                          job,
                          job_ref,
                          message,
                          goal_state,
                          error_state=None,
                          stream_driver_log=False,
                          log_poll_period_s=1,
                          dataproc_poll_period_s=10,
                          timeout_s=None):
  """Poll dataproc Job until its status is terminal or timeout reached.

  The loop wakes up every `log_poll_period_s` seconds to forward driver output
  (when streaming is enabled) and refreshes the job from the Dataproc API at
  most every `dataproc_poll_period_s` seconds, or sooner when the output
  stream is still expected or has just closed.

  Args:
    dataproc: wrapper for dataproc resources, client and messages
    job: The job to wait to finish.
    job_ref: Parsed dataproc.projects.regions.jobs resource containing a
      projectId, region, and jobId.
    message: str, message to display to user while polling.
    goal_state: JobStatus.StateValueValuesEnum, the state to define success
    error_state: JobStatus.StateValueValuesEnum, the state to define failure
    stream_driver_log: bool, Whether to show the Job's driver's output.
    log_poll_period_s: number, delay in seconds between checking on the log.
    dataproc_poll_period_s: number, delay in seconds between requests to the
      Dataproc API.
    timeout_s: number, time out for job completion. None means no timeout.

  Returns:
    Job: the return value of the last successful jobs.get request.

  Raises:
    JobError: if the job finishes with an error, or reaches a terminal state
      other than `goal_state`.
    JobTimeoutError: if the job is not in a terminal state when polling stops.
    apitools_exceptions.HttpError: if a jobs.get request fails with a 4xx
      error.
  """
  request = dataproc.messages.DataprocProjectsRegionsJobsGetRequest(
      projectId=job_ref.projectId, region=job_ref.region, jobId=job_ref.jobId)
  driver_log_stream = None
  last_job_poll_time = 0
  job_complete = False
  wait_display = None
  driver_output_uri = None

  # Forwards any available driver output to stderr while the stream is open.
  def ReadDriverLogIfPresent():
    if driver_log_stream and driver_log_stream.open:
      # TODO(b/36049794): Don't read all output.
      driver_log_stream.ReadIntoWritable(log.err)

  # Prints a separator line of '=' sized from the reported terminal size.
  def PrintEqualsLine():
    attr = console_attr.GetConsoleAttr()
    log.err.Print('=' * attr.GetTermSize()[0])

  # When streaming driver output no progress tracker is shown; a no-op context
  # manager is used in its place.
  if stream_driver_log:
    log.status.Print('Waiting for job output...')
    wait_display = NoOpProgressDisplay()
  else:
    wait_display = progress_tracker.ProgressTracker(message, autotick=True)
  start_time = now = time.time()
  with wait_display:
    # A falsy timeout_s (None or 0) means there is no deadline.
    while not timeout_s or timeout_s > (now - start_time):
      # Poll logs first to see if it closed.
      ReadDriverLogIfPresent()
      log_stream_closed = driver_log_stream and not driver_log_stream.open
      if (not job_complete and
          job.status.state in dataproc.terminal_job_states):
        job_complete = True
        # Wait another 10s to get trailing output. This replaces any
        # caller-supplied timeout with a deadline 10 seconds from now.
        timeout_s = now - start_time + 10

      if job_complete and (not stream_driver_log or log_stream_closed):
        # Nothing left to wait for
        break

      regular_job_poll = (
          not job_complete
          # Poll less frequently on dataproc API
          and now >= last_job_poll_time + dataproc_poll_period_s)
      # Poll at regular frequency before output has streamed and after it has
      # finished.
      expecting_output_stream = stream_driver_log and not driver_log_stream
      expecting_job_done = not job_complete and log_stream_closed
      if regular_job_poll or expecting_output_stream or expecting_job_done:
        last_job_poll_time = now
        try:
          job = dataproc.client.projects_regions_jobs.Get(request)
        except apitools_exceptions.HttpError as error:
          log.warning('GetJob failed:\n{}'.format(six.text_type(error)))
          # Do not retry on 4xx errors.
          if IsClientHttpException(error):
            raise
        # A changed driver output URI is treated as a new job attempt: announce
        # it (unless this is the first URI seen) and stream from the new URI.
        if (stream_driver_log and job.driverOutputResourceUri and
            job.driverOutputResourceUri != driver_output_uri):
          if driver_output_uri:
            PrintEqualsLine()
            log.warning("Job attempt failed. Streaming new attempt's output.")
            PrintEqualsLine()
          driver_output_uri = job.driverOutputResourceUri
          driver_log_stream = storage_helpers.StorageObjectSeriesStream(
              job.driverOutputResourceUri)
      time.sleep(log_poll_period_s)
      now = time.time()

  state = job.status.state

  # goal_state and error_state will always be terminal
  if state in dataproc.terminal_job_states:
    if stream_driver_log:
      if not driver_log_stream:
        log.warning('Expected job output not found.')
      elif driver_log_stream.open:
        log.warning('Job terminated, but output did not finish streaming.')
    if state is goal_state:
      return job
    if error_state and state is error_state:
      if job.status.details:
        raise exceptions.JobError('Job [{0}] failed with error:\n{1}'.format(
            job_ref.jobId, job.status.details))
      raise exceptions.JobError('Job [{0}] failed.'.format(job_ref.jobId))
    # Terminal, but neither the goal state nor the declared error state.
    if job.status.details:
      log.info('Details:\n' + job.status.details)
    raise exceptions.JobError(
        'Job [{0}] entered state [{1}] while waiting for [{2}].'.format(
            job_ref.jobId, state, goal_state))
  # The loop ended on the deadline while the job was still non-terminal.
  raise exceptions.JobTimeoutError(
      'Job [{0}] timed out while in state [{1}].'.format(job_ref.jobId, state))


# This replicates the fallthrough logic of flags._RegionAttributeConfig.
# It is necessary in cases like the --region flag where we are not parsing
# ResourceSpecs
def ResolveRegion():
  """Returns the value of the dataproc/region property via GetOrFail()."""
  region_property = properties.VALUES.dataproc.region
  return region_property.GetOrFail()


# This replicates the fallthrough logic of flags._LocationAttributeConfig.
# It is necessary in cases like the --location flag where we are not parsing
# ResourceSpecs
def ResolveLocation():
  """Returns the value of the dataproc/location property via GetOrFail()."""
  location_property = properties.VALUES.dataproc.location
  return location_property.GetOrFail()


# You probably want to use flags.AddClusterResourceArgument instead.
# If calling this method, you *must* have called flags.AddRegionFlag first to
# ensure a --region flag is stored into properties, which ResolveRegion
# depends on. This is also mutually incompatible with any usage of args.CONCEPTS
# which use --region as a resource attribute.
def ParseCluster(name, dataproc):
  """Parses a cluster name into a dataproc.projects.regions.clusters ref.

  Args:
    name: the cluster name to parse.
    dataproc: wrapper for dataproc resources, client and messages.

  Returns:
    The reference produced by `dataproc.resources.Parse`, with the region and
    project supplied through ResolveRegion and the core/project property.
  """
  fallback_params = {
      'region': ResolveRegion,
      'projectId': properties.VALUES.core.project.GetOrFail,
  }
  return dataproc.resources.Parse(
      name,
      params=fallback_params,
      collection='dataproc.projects.regions.clusters')


# You probably want to use flags.AddJobResourceArgument instead.
# If calling this method, you *must* have called flags.AddRegionFlag first to
# ensure a --region flag is stored into properties, which ResolveRegion
# depends on. This is also mutually incompatible with any usage of args.CONCEPTS
# which use --region as a resource attribute.
def ParseJob(job_id, dataproc):
  """Parses a job id into a dataproc.projects.regions.jobs ref.

  Args:
    job_id: the job id to parse.
    dataproc: wrapper for dataproc resources, client and messages.

  Returns:
    The reference produced by `dataproc.resources.Parse`, with the region and
    project supplied through ResolveRegion and the core/project property.
  """
  fallback_params = {
      'region': ResolveRegion,
      'projectId': properties.VALUES.core.project.GetOrFail,
  }
  return dataproc.resources.Parse(
      job_id,
      params=fallback_params,
      collection='dataproc.projects.regions.jobs')


def ParseOperationJsonMetadata(metadata_value, metadata_type):
  """Converts an operation's metadata value into a `metadata_type` message.

  Args:
    metadata_value: the operation metadata to convert; may be empty or None.
    metadata_type: the message class to produce.

  Returns:
    A default-constructed `metadata_type` when `metadata_value` is falsy,
    otherwise `metadata_value` serialized to JSON and decoded as
    `metadata_type`.
  """
  if metadata_value:
    metadata_json = encoding.MessageToJson(metadata_value)
    return encoding.JsonToMessage(metadata_type, metadata_json)
  return metadata_type()


# Used in bizarre scenarios where we want a qualified region rather than a
# short name
def ParseRegion(dataproc):
  """Builds a dataproc.projects.regions ref from the configured properties.

  Args:
    dataproc: wrapper for dataproc resources, client and messages.

  Returns:
    The reference produced by `dataproc.resources.Parse` with no explicit
    name; the region and project come from ResolveRegion and the core/project
    property.
  """
  fallback_params = {
      'regionId': ResolveRegion,
      'projectId': properties.VALUES.core.project.GetOrFail,
  }
  return dataproc.resources.Parse(
      None,
      params=fallback_params,
      collection='dataproc.projects.regions')


# Get dataproc.projects.locations resource
def ParseProjectsLocations(dataproc):
  """Builds a dataproc.projects.locations ref from the configured region.

  Args:
    dataproc: wrapper for dataproc resources, client and messages.

  Returns:
    The reference produced by `dataproc.resources.Parse` with no explicit
    name; the location comes from ResolveRegion and the project from the
    core/project property.
  """
  fallback_params = {
      'locationsId': ResolveRegion,
      'projectsId': properties.VALUES.core.project.GetOrFail,
  }
  return dataproc.resources.Parse(
      None,
      params=fallback_params,
      collection='dataproc.projects.locations')


# Get dataproc.projects.locations resource
# This can be merged with ParseProjectsLocations() once we have migrated batches
# from `region` to `location`.
def ParseProjectsLocationsForSession(dataproc):
  """Builds a dataproc.projects.locations ref from the configured location.

  Args:
    dataproc: wrapper for dataproc resources, client and messages.

  Returns:
    The reference produced by `dataproc.resources.Parse` with no explicit
    name; the location is the value returned by ResolveLocation() and the
    project comes from the core/project property.
  """
  session_params = {
      # The location is resolved here, before Parse is called, rather than
      # being handed over as a callable.
      'locationsId': ResolveLocation(),
      'projectsId': properties.VALUES.core.project.GetOrFail,
  }
  return dataproc.resources.Parse(
      None,
      params=session_params,
      collection='dataproc.projects.locations')


def ReadAutoscalingPolicy(dataproc, policy_id, policy_file_name=None):
  """Returns autoscaling policy read from YAML file.

  Args:
    dataproc: wrapper for dataproc resources, client and messages.
    policy_id: The autoscaling policy id (last piece of the resource name).
    policy_file_name: if set, location of the YAML file to read from. Otherwise,
      reads from stdin.

  Returns:
    The AutoscalingPolicy message imported from the file, with `id` set to
    `policy_id`, `name` cleared, and the duration fields of `basicAlgorithm`
    rewritten as '<seconds>s' strings.

  Raises:
    argparse.ArgumentError if duration formats are invalid or out of bounds.
      NOTE(review): the error is whatever arg_parsers.Duration raises; confirm
      the exact exception type.
  """
  data = console_io.ReadFromFileOrStdin(policy_file_name or '-', binary=False)
  policy = export_util.Import(
      message_type=dataproc.messages.AutoscalingPolicy, stream=data)

  # Ignore user set id in the file (if any), and overwrite with the policy_ref
  # provided with this command
  policy.id = policy_id

  # Similarly, ignore the set resource name. This field is OUTPUT_ONLY, so we
  # can just clear it.
  policy.name = None

  # Set duration fields to their seconds values
  if policy.basicAlgorithm is not None:
    if policy.basicAlgorithm.cooldownPeriod is not None:
      policy.basicAlgorithm.cooldownPeriod = str(
          arg_parsers.Duration(lower_bound='2m', upper_bound='1d')(
              policy.basicAlgorithm.cooldownPeriod)) + 's'
    # A policy file may omit yarnConfig entirely; without this guard the
    # attribute access below fails with AttributeError instead of leaving the
    # validation of the incomplete policy to the server.
    yarn_config = policy.basicAlgorithm.yarnConfig
    if (yarn_config is not None and
        yarn_config.gracefulDecommissionTimeout is not None):
      policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout = str(
          arg_parsers.Duration(lower_bound='0s', upper_bound='1d')
          (policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout)) + 's'

  return policy


def CreateAutoscalingPolicy(dataproc, name, policy):
  """Creates the given policy and returns the policy the server sends back.

  Args:
    dataproc: wrapper for dataproc resources, client and messages.
    name: The autoscaling policy resource name.
    policy: The AutoscalingPolicy message to create.

  Returns:
    The AutoscalingPolicy returned by the Create call.
  """
  # TODO(b/109837200) make the dataproc discovery doc parameters consistent
  # Parent() fails for the collection because of projectId/projectsId and
  # regionId/regionsId inconsistencies.
  # parent = template_ref.Parent().RelativePath()
  # Until then the parent is the first four '/'-separated pieces of the name.
  parent_segments = name.split('/')[:4]
  parent = '/'.join(parent_segments)

  messages = dataproc.messages
  create_request = (
      messages.DataprocProjectsRegionsAutoscalingPoliciesCreateRequest(
          parent=parent, autoscalingPolicy=policy))
  created_policy = (
      dataproc.client.projects_regions_autoscalingPolicies.Create(
          create_request))
  log.status.Print('Created [{0}].'.format(created_policy.id))
  return created_policy


def UpdateAutoscalingPolicy(dataproc, name, policy):
  """Updates the given policy and returns the policy the server sends back.

  Args:
    dataproc: wrapper for dataproc resources, client and messages.
    name: The autoscaling policy resource name.
    policy: The AutoscalingPolicy message to update; its `name` field is
      overwritten with `name`.

  Returns:
    The AutoscalingPolicy returned by the Update call.
  """
  # Though the name field is OUTPUT_ONLY in the API, the Update() method of the
  # gcloud generated dataproc client expects it to be set.
  policy.name = name

  policies_service = dataproc.client.projects_regions_autoscalingPolicies
  updated_policy = policies_service.Update(policy)
  log.status.Print('Updated [{0}].'.format(updated_policy.id))
  return updated_policy


def _DownscopeCredentials(token, access_boundary_json):
  """Exchanges an access token for one restricted to an access boundary.

  The exchange is a form-encoded POST to the STS token endpoint of the
  configured universe domain; the mTLS endpoint is used when the
  context_aware/use_client_certificate property is true.

  Args:
    token: The access token to downscope.
    access_boundary_json: The JSON-formatted access boundary.

  Returns:
    The 'access_token' entry of the STS response, or None if the response
    carries no such entry.

  Raises:
    ValueError: if the STS endpoint answers with a status other than 200.
  """
  universe_domain = properties.VALUES.core.universe_domain.Get()
  if properties.VALUES.context_aware.use_client_certificate.GetBool():
    cab_token_url = f'https://sts.mtls.{universe_domain}/v1/token'
  else:
    cab_token_url = f'https://sts.{universe_domain}/v1/token'

  exchange_request = {
      'grant_type': 'urn:ietf:params:oauth:grant-type:token-exchange',
      'requested_token_type': 'urn:ietf:params:oauth:token-type:access_token',
      'subject_token_type': 'urn:ietf:params:oauth:token-type:access_token',
      'subject_token': token,
      'options': access_boundary_json
  }
  session = requests.GetSession()
  sts_response = session.post(
      cab_token_url,
      headers={'Content-Type': 'application/x-www-form-urlencoded'},
      data=exchange_request)
  if sts_response.status_code != 200:
    raise ValueError('Error downscoping credentials')
  return json.loads(sts_response.content).get('access_token')


def GetCredentials(access_boundary_json):
  """Returns a downscoped access token for the user's current credentials.

  The current credentials are loaded (account impersonation allowed) and
  refreshed, then their token is downscoped to the given access boundary.

  Args:
    access_boundary_json: JSON string holding the definition of the access
      boundary to apply to the credentials.

  Raises:
    PersonalAuthError: If no access token could be fetched for the user.

  Returns:
    The token produced by downscoping the user's access token.
  """
  credentials = c_store.Load(None, allow_account_impersonation=True)
  c_store.Refresh(credentials)

  access_token = credentials.token
  if access_token:
    return _DownscopeCredentials(access_token, access_boundary_json)
  raise exceptions.PersonalAuthError(
      'No access token could be obtained from the current credentials.')


class PersonalAuthUtils(object):
  """Util functions for enabling personal auth session.

  A secret is encrypted with a session/cluster public key either by invoking
  an openssl executable or, when no executable is given, through the Tink
  cryptography library.
  """

  def __init__(self):
    pass

  def _RunOpensslCommand(self, openssl_executable, args, stdin=None):
    """Run the specified command, capturing and returning output as appropriate.

    Args:
      openssl_executable: The path to the openssl executable.
      args: The arguments to the openssl command to run.
      stdin: The input to the command.

    Returns:
      The output of the command.

    Raises:
      PersonalAuthError: If the call to openssl fails
    """
    command = [openssl_executable]
    command.extend(args)
    stderr = None
    try:
      if getattr(subprocess, 'run', None):
        proc = subprocess.run(
            command,
            input=stdin,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False)
        stderr = proc.stderr.decode('utf-8').strip()
        # N.B. It would be better if we could simply call `subprocess.run` with
        # the `check` keyword arg set to true rather than manually calling
        # `check_returncode`. However, we want to capture the stderr when the
        # command fails, and the CalledProcessError type did not have a field
        # for the stderr until Python version 3.5.
        #
        # As such, we need to manually call `check_returncode` as long as we
        # are supporting Python versions prior to 3.5.
        proc.check_returncode()
        return proc.stdout
      else:
        p = subprocess.Popen(
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, raw_stderr = p.communicate(input=stdin)
        stderr = raw_stderr.decode('utf-8').strip()
        # Mirror the subprocess.run branch: a non-zero exit status is a
        # failure and must not be returned as if it were valid output.
        if p.returncode != 0:
          raise subprocess.CalledProcessError(p.returncode, command)
        return stdout
    except Exception as ex:
      if stderr:
        log.error('OpenSSL command "%s" failed with error message "%s"',
                  ' '.join(command), stderr)
      raise exceptions.PersonalAuthError('Failure running openssl command: "' +
                                         ' '.join(command) + '": ' +
                                         six.text_type(ex))

  def _ComputeHmac(self, key, data, openssl_executable):
    """Compute HMAC tag using OpenSSL.

    Args:
      key: str, the key passed to `openssl dgst -sha256 -hmac`.
      data: bytes, the message to authenticate, fed to openssl on stdin.
      openssl_executable: The path to the openssl executable.

    Returns:
      The 64-character hex digest followed by a newline, UTF-8 encoded.

    Raises:
      PersonalAuthError: If openssl fails or its output does not contain a
        64-character hexadecimal digest.
    """
    cmd_output = self._RunOpensslCommand(
        openssl_executable, ['dgst', '-sha256', '-hmac', key],
        stdin=data).decode('utf-8')
    try:
      # Split the openssl output to get the HMAC.
      stripped_output = cmd_output.strip().split(' ')[1]
      if len(stripped_output) != 64:
        raise ValueError('HMAC output is expected to be 64 characters long.')
      int(stripped_output, 16)  # Check that the HMAC is in hex format.
    except Exception as ex:
      raise exceptions.PersonalAuthError(
          'Failure due to invalid openssl output: ' + six.text_type(ex))
    return (stripped_output + '\n').encode('utf-8')

  def _DeriveHkdfKey(self, prk, info, openssl_executable):
    """Derives HMAC-based Key Derivation Function (HKDF) key through expansion on the initial pseudorandom key.

    Args:
      prk: a pseudorandom key.
      info: optional context and application specific information (can be
        empty).
      openssl_executable: The path to the openssl executable.

    Returns:
      Output keying material, expected to be of 256-bit length. The value is
      returned in the form produced by _ComputeHmac.

    Raises:
      ValueError: If `prk` is not 32 bytes long.
    """
    if len(prk) != 32:
      raise ValueError(
          'The given initial pseudorandom key is expected to be 32 bytes long.')
    base16_prk = base64.b16encode(prk).decode('utf-8')
    # Two chained HMAC rounds: the first over empty input, the second over
    # the first tag followed by `info` and a 0x01 byte.
    t1 = self._ComputeHmac(base16_prk, b'', openssl_executable)
    t2data = bytearray(t1)
    t2data.extend(info)
    t2data.extend(b'\x01')
    return self._ComputeHmac(base16_prk, t2data, openssl_executable)

  # It is possible (although very rare) for the random pad generated by
  # openssl to not be usable by openssl for encrypting the secret. When
  # that happens the call to openssl will raise a CalledProcessError with
  # the message "Error reading password from BIO\nError getting password".
  #
  # To account for this we retry on that error, but this is so rare that
  # a single retry should be sufficient.
  @retry.RetryOnException(max_retrials=1)
  def _EncodeTokenUsingOpenssl(self, public_key, secret, openssl_executable):
    """Encode token using OpenSSL.

    Args:
      public_key: The public key for the session/cluster.
      secret: Token to be encrypted.
      openssl_executable: The path to the openssl executable.

    Returns:
      Encrypted token, formatted as the colon-separated string
      '<key hash>:<encrypted token>:<encrypted key>:<iv>:<hmac tag>'.

    Raises:
      ValueError: If the RSA-encrypted key is not 512 bytes long.
      PersonalAuthError: If an openssl invocation fails.
    """
    key_hash = hashlib.sha256((public_key + '\n').encode('utf-8')).hexdigest()
    iv_bytes = base64.b16encode(os.urandom(16))
    initialization_vector = iv_bytes.decode('utf-8')
    initial_key = os.urandom(32)
    encryption_key = self._DeriveHkdfKey(initial_key,
                                         'encryption_key'.encode('utf-8'),
                                         openssl_executable)
    auth_key = base64.b16encode(
        self._DeriveHkdfKey(initial_key, 'auth_key'.encode('utf-8'),
                            openssl_executable)).decode('utf-8')
    # openssl reads the public key from a file, so it is staged in a
    # temporary file that only lives for the duration of the command.
    with tempfile.NamedTemporaryFile() as kf:
      kf.write(public_key.encode('utf-8'))
      kf.seek(0)
      encrypted_key = self._RunOpensslCommand(
          openssl_executable,
          ['rsautl', '-oaep', '-encrypt', '-pubin', '-inkey', kf.name],
          stdin=base64.b64encode(initial_key))
    if len(encrypted_key) != 512:
      raise ValueError('The encrypted key is expected to be 512 bytes long.')
    encoded_key = base64.b64encode(encrypted_key).decode('utf-8')

    # The derived encryption key is handed to openssl through a temporary
    # password file rather than on the command line.
    with tempfile.NamedTemporaryFile() as pf:
      pf.write(encryption_key)
      pf.seek(0)
      encrypt_args = [
          'enc', '-aes-256-ctr', '-salt', '-iv', initialization_vector, '-pass',
          'file:{}'.format(pf.name)
      ]
      encrypted_token = self._RunOpensslCommand(
          openssl_executable, encrypt_args, stdin=secret.encode('utf-8'))
    # NOTE(review): the length of encrypted_key is not re-checked here; it was
    # already validated above and has not changed since.
    encoded_token = base64.b64encode(encrypted_token).decode('utf-8')

    hmac_input = bytearray(iv_bytes)
    hmac_input.extend(encrypted_token)
    hmac_tag = self._ComputeHmac(auth_key, hmac_input,
                                 openssl_executable).decode('utf-8')[
                                     0:32]  # Truncate the HMAC tag to 128-bit
    return '{}:{}:{}:{}:{}'.format(key_hash, encoded_token, encoded_key,
                                   initialization_vector, hmac_tag)

  def EncryptWithPublicKey(self, public_key, secret, openssl_executable):
    """Encrypt secret with resource public key.

    Uses openssl when an executable is provided; otherwise falls back to the
    Tink library, in which case `public_key` is read as a JSON keyset.

    Args:
      public_key: The public key for the session/cluster.
      secret: Token to be encrypted.
      openssl_executable: The path to the openssl executable.

    Returns:
      Encrypted token.

    Raises:
      PersonalAuthError: If no openssl executable is given and the Tink
        library cannot be imported.
    """
    if openssl_executable:
      return self._EncodeTokenUsingOpenssl(public_key, secret,
                                           openssl_executable)
    try:
      # pylint: disable=g-import-not-at-top
      import tink
      from tink import hybrid
      # pylint: enable=g-import-not-at-top
    except ImportError:
      raise exceptions.PersonalAuthError(
          'Cannot load the Tink cryptography library. Either the '
          'library is not installed, or site packages are not '
          'enabled for the Google Cloud SDK. Please consult Cloud '
          'Dataproc Personal Auth documentation on adding Tink to '
          'Google Cloud SDK for further instructions.\n'
          'https://cloud.google.com/dataproc/docs/concepts/iam/personal-auth')
    hybrid.register()
    context = b''

    # Extract value of key corresponding to primary key.
    public_key_value = json.loads(public_key)['key'][0]['keyData']['value']
    key_hash = hashlib.sha256(
        (public_key_value + '\n').encode('utf-8')).hexdigest()

    # Load public key and create keyset handle.
    reader = tink.JsonKeysetReader(public_key)
    kh_pub = tink.read_no_secret_keyset_handle(reader)

    # Create encrypter instance.
    encrypter = kh_pub.primitive(hybrid.HybridEncrypt)
    ciphertext = encrypter.encrypt(secret.encode('utf-8'), context)

    encoded_token = base64.b64encode(ciphertext).decode('utf-8')
    return '{}:{}'.format(key_hash, encoded_token)

  def IsTinkLibraryInstalled(self):
    """Check if Tink cryptography library can be loaded.

    Returns:
      True if both `tink` and `tink.hybrid` can be imported, False otherwise.
    """
    try:
      # pylint: disable=g-import-not-at-top
      # pylint: disable=unused-import
      import tink
      from tink import hybrid
      # pylint: enable=g-import-not-at-top
      # pylint: enable=unused-import
      return True
    except ImportError:
      return False


def ReadSessionTemplate(dataproc, template_file_name=None):
  """Reads a YAML document and imports it as a SessionTemplate message.

  Args:
    dataproc: Wrapper for dataproc resources, client and messages.
    template_file_name: Location of the YAML file to read from. When unset
      (or empty), '-' is passed to the reader so input comes from stdin.

  Returns:
    The SessionTemplate message produced by export_util.Import.

  Raises:
    Nothing is caught here: any error raised by
    console_io.ReadFromFileOrStdin or export_util.Import propagates to the
    caller unchanged.
  """
  source = template_file_name if template_file_name else '-'
  yaml_text = console_io.ReadFromFileOrStdin(source, binary=False)
  return export_util.Import(
      message_type=dataproc.messages.SessionTemplate, stream=yaml_text)


def CreateSessionTemplate(dataproc, name, template):
  """Creates a session template and returns the server's response.

  Args:
    dataproc: Wrapper for dataproc resources, client and messages.
    name: The session template resource name. Its first four '/'-separated
      segments are used as the parent of the create request.
    template: The SessionTemplate message to create. Its `name` field is
      overwritten with `name` before the request is sent.

  Returns:
    The SessionTemplate message returned by the Create call.
  """
  # The parent is the leading four path segments of the resource name.
  parent_segments = name.split('/')[:4]
  template.name = name

  create_request_type = (
      dataproc.messages.DataprocProjectsLocationsSessionTemplatesCreateRequest)
  create_request = create_request_type(
      parent='/'.join(parent_segments), sessionTemplate=template)

  created = dataproc.client.projects_locations_sessionTemplates.Create(
      create_request)
  log.status.Print('Created [{0}].'.format(created.name))
  return created


def UpdateSessionTemplate(dataproc, name, template):
  """Patches a session template and returns the server's response.

  Args:
    dataproc: Wrapper for dataproc resources, client and messages.
    name: The session template resource name.
    template: The SessionTemplate message carrying the update. Its `name`
      field is overwritten with `name` before the Patch call.

  Returns:
    The SessionTemplate message returned by the Patch call.
  """
  template.name = name

  session_templates = dataproc.client.projects_locations_sessionTemplates
  updated = session_templates.Patch(template)
  log.status.Print('Updated [{0}].'.format(updated.name))
  return updated


def YieldFromListWithUnreachableList(unreachable_warning_msg, *args, **kwargs):
  """Yields paged List results, then warns about any unreachable entries.

  Args:
    unreachable_warning_msg: Message handed to log.warning, with the sorted,
      comma-joined unreachable entries supplied as its single argument.
    *args: Positional arguments forwarded to list_pager.YieldFromList.
    **kwargs: Keyword arguments forwarded to list_pager.YieldFromList.

  Yields:
    Each item produced by list_pager.YieldFromList.
  """
  seen_unreachable = set()

  def _RecordUnreachableAndGetField(response, field_name):
    # Collect the response's `unreachable` entries as a side effect, then
    # return the requested attribute of the response.
    seen_unreachable.update(response.unreachable)
    return getattr(response, field_name)

  pager = list_pager.YieldFromList(
      *args, get_field_func=_RecordUnreachableAndGetField, **kwargs)
  for resource in pager:
    yield resource

  # The warning is emitted only once the pager has been fully consumed.
  if not seen_unreachable:
    return
  log.warning(unreachable_warning_msg, ', '.join(sorted(seen_unreachable)))
